mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-26 15:24:48 +08:00
refactor: Refactors the retriever function to use async/await (#821)
* refactor: Refactors the retriever function to use async/await
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -94,6 +95,17 @@ class DifyProvider(Retriever):
|
|||||||
|
|
||||||
return list(all_documents.values())
|
return list(all_documents.values())
|
||||||
|
|
||||||
|
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text, forwarded unchanged to query_relevant_documents.
        resources: Resources forwarded to query_relevant_documents. When
            omitted or None, a new empty list is forwarded instead.

    Returns:
        The value returned by query_relevant_documents for the same arguments.
    """
    # A literal ``[]`` default is created once and shared by every call; it
    # would also be handed to worker threads. Build a fresh list per call.
    if resources is None:
        resources = []
    return await asyncio.to_thread(self.query_relevant_documents, query, resources)
|
||||||
|
|
||||||
def list_resources(self, query: str | None = None) -> list[Resource]:
|
def list_resources(self, query: str | None = None) -> list[Resource]:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
@@ -124,6 +136,13 @@ class DifyProvider(Retriever):
|
|||||||
|
|
||||||
return resources
|
return resources
|
||||||
|
|
||||||
|
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Awaitable counterpart of list_resources.

    Hands the synchronous list_resources call to asyncio.to_thread() so the
    event loop is not blocked, and returns whatever that call returns.
    """
    sync_listing = self.list_resources
    listed = await asyncio.to_thread(sync_listing, query)
    return listed
|
||||||
|
|
||||||
|
|
||||||
def parse_uri(uri: str) -> tuple[str, str]:
|
def parse_uri(uri: str) -> tuple[str, str]:
|
||||||
parsed = urlparse(uri)
|
parsed = urlparse(uri)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@@ -514,6 +515,13 @@ class MilvusRetriever(Retriever):
|
|||||||
return self._list_local_markdown_resources()
|
return self._list_local_markdown_resources()
|
||||||
return resources
|
return resources
|
||||||
|
|
||||||
|
async def list_resources_async(self, query: Optional[str] = None) -> List[Resource]:
    """
    Awaitable counterpart of list_resources.

    Hands the synchronous list_resources call to asyncio.to_thread() so the
    event loop is not blocked, and returns whatever that call returns.
    """
    sync_listing = self.list_resources
    listed = await asyncio.to_thread(sync_listing, query)
    return listed
|
||||||
|
|
||||||
def _list_local_markdown_resources(self) -> List[Resource]:
|
def _list_local_markdown_resources(self) -> List[Resource]:
|
||||||
"""Return local example markdown files as ``Resource`` objects.
|
"""Return local example markdown files as ``Resource`` objects.
|
||||||
|
|
||||||
@@ -661,6 +669,17 @@ class MilvusRetriever(Retriever):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to query documents from Milvus: {str(e)}")
|
raise RuntimeError(f"Failed to query documents from Milvus: {str(e)}")
|
||||||
|
|
||||||
|
async def query_relevant_documents_async(
    self, query: str, resources: Optional[List[Resource]] = None
) -> List[Document]:
    """
    Awaitable counterpart of query_relevant_documents.

    Hands the synchronous query_relevant_documents call to
    asyncio.to_thread() so the event loop is not blocked. Both arguments are
    forwarded unchanged (including a None ``resources``).
    """
    sync_query = self.query_relevant_documents
    matched = await asyncio.to_thread(sync_query, query, resources)
    return matched
|
||||||
|
|
||||||
def create_collection(self) -> None:
|
def create_collection(self) -> None:
|
||||||
"""Public hook ensuring collection exists (explicit initialization)."""
|
"""Public hook ensuring collection exists (explicit initialization)."""
|
||||||
if not self.client:
|
if not self.client:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -108,6 +109,17 @@ class MOIProvider(Retriever):
|
|||||||
|
|
||||||
return list(docs.values())
|
return list(docs.values())
|
||||||
|
|
||||||
|
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text, forwarded unchanged to query_relevant_documents.
        resources: Resources forwarded to query_relevant_documents. When
            omitted or None, a new empty list is forwarded instead.

    Returns:
        The value returned by query_relevant_documents for the same arguments.
    """
    # A literal ``[]`` default is created once and shared by every call; it
    # would also be handed to worker threads. Build a fresh list per call.
    if resources is None:
        resources = []
    return await asyncio.to_thread(self.query_relevant_documents, query, resources)
|
||||||
|
|
||||||
def list_resources(self, query: str | None = None) -> list[Resource]:
|
def list_resources(self, query: str | None = None) -> list[Resource]:
|
||||||
"""
|
"""
|
||||||
List resources from MOI API with optional query filtering and limit support.
|
List resources from MOI API with optional query filtering and limit support.
|
||||||
@@ -144,6 +156,13 @@ class MOIProvider(Retriever):
|
|||||||
|
|
||||||
return resources
|
return resources
|
||||||
|
|
||||||
|
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Awaitable counterpart of list_resources.

    Hands the synchronous list_resources call to asyncio.to_thread() so the
    event loop is not blocked, and returns whatever that call returns.
    """
    sync_listing = self.list_resources
    listed = await asyncio.to_thread(sync_listing, query)
    return listed
|
||||||
|
|
||||||
def _parse_uri(self, uri: str) -> tuple[str, str]:
|
def _parse_uri(self, uri: str) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Parse URI to extract dataset ID and document ID.
|
Parse URI to extract dataset ID and document ID.
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
@@ -351,6 +352,13 @@ class QdrantProvider(Retriever):
|
|||||||
return self._list_local_markdown_resources()
|
return self._list_local_markdown_resources()
|
||||||
return resources
|
return resources
|
||||||
|
|
||||||
|
async def list_resources_async(self, query: Optional[str] = None) -> List[Resource]:
    """
    Awaitable counterpart of list_resources.

    Hands the synchronous list_resources call to asyncio.to_thread() so the
    event loop is not blocked, and returns whatever that call returns.
    """
    sync_listing = self.list_resources
    listed = await asyncio.to_thread(sync_listing, query)
    return listed
|
||||||
|
|
||||||
def _list_local_markdown_resources(self) -> List[Resource]:
|
def _list_local_markdown_resources(self) -> List[Resource]:
|
||||||
current_file = Path(__file__)
|
current_file = Path(__file__)
|
||||||
project_root = current_file.parent.parent.parent
|
project_root = current_file.parent.parent.parent
|
||||||
@@ -419,6 +427,17 @@ class QdrantProvider(Retriever):
|
|||||||
|
|
||||||
return list(documents.values())
|
return list(documents.values())
|
||||||
|
|
||||||
|
async def query_relevant_documents_async(
    self, query: str, resources: Optional[List[Resource]] = None
) -> List[Document]:
    """
    Awaitable counterpart of query_relevant_documents.

    Hands the synchronous query_relevant_documents call to
    asyncio.to_thread() so the event loop is not blocked. Both arguments are
    forwarded unchanged (including a None ``resources``).
    """
    sync_query = self.query_relevant_documents
    matched = await asyncio.to_thread(sync_query, query, resources)
    return matched
|
||||||
|
|
||||||
def create_collection(self) -> None:
|
def create_collection(self) -> None:
|
||||||
if not self.client:
|
if not self.client:
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@@ -98,6 +99,17 @@ class RAGFlowProvider(Retriever):
|
|||||||
|
|
||||||
return list(docs.values())
|
return list(docs.values())
|
||||||
|
|
||||||
|
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text, forwarded unchanged to query_relevant_documents.
        resources: Resources forwarded to query_relevant_documents. When
            omitted or None, a new empty list is forwarded instead.

    Returns:
        The value returned by query_relevant_documents for the same arguments.
    """
    # A literal ``[]`` default is created once and shared by every call; it
    # would also be handed to worker threads. Build a fresh list per call.
    if resources is None:
        resources = []
    return await asyncio.to_thread(self.query_relevant_documents, query, resources)
|
||||||
|
|
||||||
def list_resources(self, query: str | None = None) -> list[Resource]:
|
def list_resources(self, query: str | None = None) -> list[Resource]:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
@@ -128,6 +140,13 @@ class RAGFlowProvider(Retriever):
|
|||||||
|
|
||||||
return resources
|
return resources
|
||||||
|
|
||||||
|
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Awaitable counterpart of list_resources.

    Hands the synchronous list_resources call to asyncio.to_thread() so the
    event loop is not blocked, and returns whatever that call returns.
    """
    sync_listing = self.list_resources
    listed = await asyncio.to_thread(sync_listing, query)
    return listed
|
||||||
|
|
||||||
|
|
||||||
def parse_uri(uri: str) -> tuple[str, str]:
|
def parse_uri(uri: str) -> tuple[str, str]:
|
||||||
parsed = urlparse(uri)
|
parsed = urlparse(uri)
|
||||||
|
|||||||
@@ -67,7 +67,18 @@ class Retriever(abc.ABC):
|
|||||||
@abc.abstractmethod
def list_resources(self, query: str | None = None) -> list[Resource]:
    """
    List resources from the rag provider (synchronous version).

    Abstract: the body is intentionally empty and concrete retrievers
    supply the implementation.
    """
|
||||||
|
|
||||||
|
@abc.abstractmethod
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    List resources from the rag provider (asynchronous version).

    Abstract: concrete retrievers supply the implementation, choosing one of:
    - native async I/O operations, for true non-blocking behavior; or
    - asyncio.to_thread() around the synchronous version, when async I/O
      is not available.
    """
|
||||||
|
|
||||||
@@ -76,7 +87,20 @@ class Retriever(abc.ABC):
|
|||||||
self, query: str, resources: list[Resource] = []
|
self, query: str, resources: list[Resource] = []
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Query relevant documents from the resources.
|
Query relevant documents from the resources (synchronous version).
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] = []
) -> list[Document]:
    """
    Query relevant documents from the resources (asynchronous version).

    Abstract: concrete retrievers supply the implementation, choosing one of:
    - native async I/O operations, for true non-blocking behavior; or
    - asyncio.to_thread() around the synchronous version, when async I/O
      is not available.
    """
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
@@ -255,6 +256,17 @@ class VikingDBKnowledgeBaseProvider(Retriever):
|
|||||||
|
|
||||||
return list(all_documents.values())
|
return list(all_documents.values())
|
||||||
|
|
||||||
|
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text, forwarded unchanged to query_relevant_documents.
        resources: Resources forwarded to query_relevant_documents. When
            omitted or None, a new empty list is forwarded instead.

    Returns:
        The value returned by query_relevant_documents for the same arguments.
    """
    # A literal ``[]`` default is created once and shared by every call; it
    # would also be handed to worker threads. Build a fresh list per call.
    if resources is None:
        resources = []
    return await asyncio.to_thread(self.query_relevant_documents, query, resources)
|
||||||
|
|
||||||
def list_resources(self, query: str | None = None) -> list[Resource]:
|
def list_resources(self, query: str | None = None) -> list[Resource]:
|
||||||
"""
|
"""
|
||||||
List resources (knowledge bases) from the knowledge base service
|
List resources (knowledge bases) from the knowledge base service
|
||||||
@@ -291,6 +303,13 @@ class VikingDBKnowledgeBaseProvider(Retriever):
|
|||||||
|
|
||||||
return resources
|
return resources
|
||||||
|
|
||||||
|
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Awaitable counterpart of list_resources.

    Hands the synchronous list_resources call to asyncio.to_thread() so the
    event loop is not blocked, and returns whatever that call returns.
    """
    sync_listing = self.list_resources
    listed = await asyncio.to_thread(sync_listing, query)
    return listed
|
||||||
|
|
||||||
|
|
||||||
def parse_uri(uri: str) -> tuple[str, str]:
|
def parse_uri(uri: str) -> tuple[str, str]:
|
||||||
parsed = urlparse(uri)
|
parsed = urlparse(uri)
|
||||||
|
|||||||
@@ -47,7 +47,15 @@ class RetrieverTool(BaseTool):
|
|||||||
keywords: str,
|
keywords: str,
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
return self._run(keywords, run_manager.get_sync())
|
logger.info(
|
||||||
|
f"Retriever tool query: {keywords}", extra={"resources": self.resources}
|
||||||
|
)
|
||||||
|
documents = await self.retriever.query_relevant_documents_async(
|
||||||
|
keywords, self.resources
|
||||||
|
)
|
||||||
|
if not documents:
|
||||||
|
return "No results found from the local knowledge base."
|
||||||
|
return [doc.to_dict() for doc in documents]
|
||||||
|
|
||||||
|
|
||||||
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
|
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
|
||||||
|
|||||||
@@ -56,18 +56,59 @@ def test_retriever_abstract_methods():
|
|||||||
def list_resources(self, query=None):
|
def list_resources(self, query=None):
|
||||||
return [Resource(uri="uri", title="title")]
|
return [Resource(uri="uri", title="title")]
|
||||||
|
|
||||||
|
async def list_resources_async(self, query=None):
|
||||||
|
return [Resource(uri="uri", title="title")]
|
||||||
|
|
||||||
def query_relevant_documents(self, query, resources=[]):
|
def query_relevant_documents(self, query, resources=[]):
|
||||||
return [Document(id="id", chunks=[])]
|
return [Document(id="id", chunks=[])]
|
||||||
|
|
||||||
|
async def query_relevant_documents_async(self, query, resources=[]):
|
||||||
|
return [Document(id="id", chunks=[])]
|
||||||
|
|
||||||
retriever = DummyRetriever()
|
retriever = DummyRetriever()
|
||||||
|
# Test synchronous methods
|
||||||
resources = retriever.list_resources()
|
resources = retriever.list_resources()
|
||||||
assert isinstance(resources, list)
|
assert isinstance(resources, list)
|
||||||
assert isinstance(resources[0], Resource)
|
assert isinstance(resources[0], Resource)
|
||||||
|
assert resources[0].uri == "uri"
|
||||||
|
|
||||||
docs = retriever.query_relevant_documents("query", resources)
|
docs = retriever.query_relevant_documents("query", resources)
|
||||||
assert isinstance(docs, list)
|
assert isinstance(docs, list)
|
||||||
assert isinstance(docs[0], Document)
|
assert isinstance(docs[0], Document)
|
||||||
|
assert docs[0].id == "id"
|
||||||
|
|
||||||
|
|
||||||
def test_retriever_cannot_instantiate():
    """Constructing the Retriever base class directly must raise TypeError."""
    expected_error = TypeError
    with pytest.raises(expected_error):
        Retriever()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retriever_async_methods():
    """Exercise the async overrides of a minimal concrete Retriever."""

    class DummyRetriever(Retriever):
        # The sync and async variants return different uri/id values so the
        # assertions below can tell which variant actually produced a result.
        def list_resources(self, query=None):
            return [Resource(uri="uri", title="title")]

        async def list_resources_async(self, query=None):
            return [Resource(uri="uri_async", title="title_async")]

        def query_relevant_documents(self, query, resources=[]):
            return [Document(id="id", chunks=[])]

        async def query_relevant_documents_async(self, query, resources=[]):
            return [Document(id="id_async", chunks=[])]

    retriever = DummyRetriever()

    # Async resource listing.
    resources = await retriever.list_resources_async()
    assert isinstance(resources, list)
    first_resource = resources[0]
    assert isinstance(first_resource, Resource)
    assert first_resource.uri == "uri_async"

    # Async document query, fed with the resources listed above.
    docs = await retriever.query_relevant_documents_async("query", resources)
    assert isinstance(docs, list)
    first_doc = docs[0]
    assert isinstance(first_doc, Document)
    assert first_doc.id == "id_async"
|
||||||
|
|||||||
@@ -66,18 +66,20 @@ async def test_retriever_tool_arun():
|
|||||||
mock_retriever = Mock(spec=Retriever)
|
mock_retriever = Mock(spec=Retriever)
|
||||||
chunk = Chunk(content="async content", similarity=0.8)
|
chunk = Chunk(content="async content", similarity=0.8)
|
||||||
doc = Document(id="doc2", chunks=[chunk])
|
doc = Document(id="doc2", chunks=[chunk])
|
||||||
mock_retriever.query_relevant_documents.return_value = [doc]
|
|
||||||
|
# Mock the async method
|
||||||
|
async def mock_async_query(*args, **kwargs):
|
||||||
|
return [doc]
|
||||||
|
|
||||||
|
mock_retriever.query_relevant_documents_async = mock_async_query
|
||||||
|
|
||||||
resources = [Resource(uri="test://uri", title="Test")]
|
resources = [Resource(uri="test://uri", title="Test")]
|
||||||
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
|
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
|
||||||
|
|
||||||
mock_run_manager = Mock(spec=AsyncCallbackManagerForToolRun)
|
mock_run_manager = Mock(spec=AsyncCallbackManagerForToolRun)
|
||||||
mock_sync_manager = Mock(spec=CallbackManagerForToolRun)
|
|
||||||
mock_run_manager.get_sync.return_value = mock_sync_manager
|
|
||||||
|
|
||||||
result = await tool._arun("async keywords", mock_run_manager)
|
result = await tool._arun("async keywords", mock_run_manager)
|
||||||
|
|
||||||
mock_run_manager.get_sync.assert_called_once()
|
|
||||||
assert isinstance(result, list)
|
assert isinstance(result, list)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0] == doc.to_dict()
|
assert result[0] == doc.to_dict()
|
||||||
|
|||||||
Reference in New Issue
Block a user