refactor: Refactors the retriever function to use async/await (#821)

* refactor: Refactors the retriever function to use async/await
This commit is contained in:
Xun
2026-01-20 19:56:26 +08:00
committed by GitHub
parent 2ed0eeb107
commit 0e64c52975
10 changed files with 196 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio
import os import os
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -94,6 +95,17 @@ class DifyProvider(Retriever):
return list(all_documents.values()) return list(all_documents.values())
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Delegates to the synchronous implementation through asyncio.to_thread()
    so the event loop is not blocked while the query runs.

    Args:
        query: Query text forwarded unchanged to query_relevant_documents.
        resources: Resources forwarded to query_relevant_documents. When
            omitted (or None) a fresh empty list is passed instead.

    Returns:
        Whatever query_relevant_documents returns for the same arguments.
    """
    # A literal ``[]`` default is a single list object shared by every call
    # and, via to_thread(), by every worker thread. Build a new one per call.
    selected = [] if resources is None else resources
    return await asyncio.to_thread(
        self.query_relevant_documents, query, selected
    )
def list_resources(self, query: str | None = None) -> list[Resource]: def list_resources(self, query: str | None = None) -> list[Resource]:
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
@@ -124,6 +136,13 @@ class DifyProvider(Retriever):
return resources return resources
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Hands the synchronous list_resources call to asyncio.to_thread() so the
    event loop is not blocked, then returns its result unchanged.
    """
    listed = await asyncio.to_thread(self.list_resources, query)
    return listed
def parse_uri(uri: str) -> tuple[str, str]: def parse_uri(uri: str) -> tuple[str, str]:
parsed = urlparse(uri) parsed = urlparse(uri)

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio
import hashlib import hashlib
import logging import logging
import re import re
@@ -514,6 +515,13 @@ class MilvusRetriever(Retriever):
return self._list_local_markdown_resources() return self._list_local_markdown_resources()
return resources return resources
async def list_resources_async(self, query: Optional[str] = None) -> List[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Runs the synchronous list_resources in a worker thread through
    asyncio.to_thread() to avoid blocking the event loop.
    """
    sync_listing = self.list_resources
    found = await asyncio.to_thread(sync_listing, query)
    return found
def _list_local_markdown_resources(self) -> List[Resource]: def _list_local_markdown_resources(self) -> List[Resource]:
"""Return local example markdown files as ``Resource`` objects. """Return local example markdown files as ``Resource`` objects.
@@ -661,6 +669,17 @@ class MilvusRetriever(Retriever):
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to query documents from Milvus: {str(e)}") raise RuntimeError(f"Failed to query documents from Milvus: {str(e)}")
async def query_relevant_documents_async(
    self, query: str, resources: Optional[List[Resource]] = None
) -> List[Document]:
    """
    Asynchronous counterpart of query_relevant_documents.

    Forwards ``query`` and ``resources`` unchanged to the synchronous
    implementation, executed through asyncio.to_thread() to avoid blocking
    the event loop.
    """
    sync_query = self.query_relevant_documents
    documents = await asyncio.to_thread(sync_query, query, resources)
    return documents
def create_collection(self) -> None: def create_collection(self) -> None:
"""Public hook ensuring collection exists (explicit initialization).""" """Public hook ensuring collection exists (explicit initialization)."""
if not self.client: if not self.client:

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio
import os import os
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -108,6 +109,17 @@ class MOIProvider(Retriever):
return list(docs.values()) return list(docs.values())
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text forwarded unchanged to query_relevant_documents.
        resources: Resources forwarded to query_relevant_documents. When
            omitted (or None) a fresh empty list is passed instead.

    Returns:
        Whatever query_relevant_documents returns for the same arguments.
    """
    # Avoid a mutable ``[]`` default: that one list would be shared between
    # all calls and passed into worker threads. Create a new list each time.
    effective_resources = resources if resources is not None else []
    return await asyncio.to_thread(
        self.query_relevant_documents, query, effective_resources
    )
def list_resources(self, query: str | None = None) -> list[Resource]: def list_resources(self, query: str | None = None) -> list[Resource]:
""" """
List resources from MOI API with optional query filtering and limit support. List resources from MOI API with optional query filtering and limit support.
@@ -144,6 +156,13 @@ class MOIProvider(Retriever):
return resources return resources
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Asynchronous counterpart of list_resources.

    The synchronous list_resources is executed through asyncio.to_thread()
    to avoid blocking the event loop; its result is returned as-is.
    """
    resources = await asyncio.to_thread(self.list_resources, query)
    return resources
def _parse_uri(self, uri: str) -> tuple[str, str]: def _parse_uri(self, uri: str) -> tuple[str, str]:
""" """
Parse URI to extract dataset ID and document ID. Parse URI to extract dataset ID and document ID.

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio
import hashlib import hashlib
import logging import logging
import uuid import uuid
@@ -351,6 +352,13 @@ class QdrantProvider(Retriever):
return self._list_local_markdown_resources() return self._list_local_markdown_resources()
return resources return resources
async def list_resources_async(self, query: Optional[str] = None) -> List[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Delegates to the synchronous list_resources through asyncio.to_thread()
    to avoid blocking the event loop.
    """
    listing = await asyncio.to_thread(self.list_resources, query)
    return listing
def _list_local_markdown_resources(self) -> List[Resource]: def _list_local_markdown_resources(self) -> List[Resource]:
current_file = Path(__file__) current_file = Path(__file__)
project_root = current_file.parent.parent.parent project_root = current_file.parent.parent.parent
@@ -419,6 +427,17 @@ class QdrantProvider(Retriever):
return list(documents.values()) return list(documents.values())
async def query_relevant_documents_async(
    self, query: str, resources: Optional[List[Resource]] = None
) -> List[Document]:
    """
    Asynchronous counterpart of query_relevant_documents.

    Passes ``query`` and ``resources`` straight through to the synchronous
    implementation, which runs through asyncio.to_thread() to avoid
    blocking the event loop.
    """
    matched = await asyncio.to_thread(
        self.query_relevant_documents, query, resources
    )
    return matched
def create_collection(self) -> None: def create_collection(self) -> None:
if not self.client: if not self.client:
self._connect() self._connect()

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio
import os import os
from typing import List, Optional from typing import List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -98,6 +99,17 @@ class RAGFlowProvider(Retriever):
return list(docs.values()) return list(docs.values())
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text forwarded unchanged to query_relevant_documents.
        resources: Resources forwarded to query_relevant_documents. When
            omitted (or None) a fresh empty list is passed instead.

    Returns:
        Whatever query_relevant_documents returns for the same arguments.
    """
    # Replace the shared mutable ``[]`` default with a per-call list so no
    # state can leak between calls or between worker threads.
    if resources is None:
        resources = []
    return await asyncio.to_thread(
        self.query_relevant_documents, query, resources
    )
def list_resources(self, query: str | None = None) -> list[Resource]: def list_resources(self, query: str | None = None) -> list[Resource]:
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
@@ -128,6 +140,13 @@ class RAGFlowProvider(Retriever):
return resources return resources
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Executes the synchronous list_resources through asyncio.to_thread() to
    avoid blocking the event loop and returns what it produced.
    """
    sync_listing = self.list_resources
    available = await asyncio.to_thread(sync_listing, query)
    return available
def parse_uri(uri: str) -> tuple[str, str]: def parse_uri(uri: str) -> tuple[str, str]:
parsed = urlparse(uri) parsed = urlparse(uri)

View File

@@ -67,7 +67,18 @@ class Retriever(abc.ABC):
@abc.abstractmethod
def list_resources(self, query: str | None = None) -> list[Resource]:
    """
    Return the resources offered by the rag provider (synchronous version).

    Abstract: concrete retrievers must override this method.
    """
@abc.abstractmethod
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Return the resources offered by the rag provider (asynchronous version).

    Abstract: concrete retrievers must override this method. An
    implementation should either:
    - perform native async I/O for true non-blocking behavior, or
    - wrap the synchronous version with asyncio.to_thread() when async I/O
      is not available.
    """
@@ -76,7 +87,20 @@ class Retriever(abc.ABC):
self, query: str, resources: list[Resource] = [] self, query: str, resources: list[Resource] = []
) -> list[Document]: ) -> list[Document]:
""" """
Query relevant documents from the resources. Query relevant documents from the resources (synchronous version).
"""
pass
@abc.abstractmethod
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] = []
) -> list[Document]:
    """
    Look up documents relevant to ``query`` within ``resources``
    (asynchronous version).

    Abstract: concrete retrievers must override this method. An
    implementation should either:
    - perform native async I/O for true non-blocking behavior, or
    - wrap the synchronous version with asyncio.to_thread() when async I/O
      is not available.
    """

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio
import hashlib import hashlib
import hmac import hmac
import json import json
@@ -255,6 +256,17 @@ class VikingDBKnowledgeBaseProvider(Retriever):
return list(all_documents.values()) return list(all_documents.values())
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text forwarded unchanged to query_relevant_documents.
        resources: Resources forwarded to query_relevant_documents. When
            omitted (or None) a fresh empty list is passed instead.

    Returns:
        Whatever query_relevant_documents returns for the same arguments.
    """
    # Use a None sentinel rather than ``[]``: a mutable default is created
    # once and would be shared by every call and every worker thread.
    targets = [] if resources is None else resources
    return await asyncio.to_thread(
        self.query_relevant_documents, query, targets
    )
def list_resources(self, query: str | None = None) -> list[Resource]: def list_resources(self, query: str | None = None) -> list[Resource]:
""" """
List resources (knowledge bases) from the knowledge base service List resources (knowledge bases) from the knowledge base service
@@ -291,6 +303,13 @@ class VikingDBKnowledgeBaseProvider(Retriever):
return resources return resources
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Calls the synchronous list_resources through asyncio.to_thread() to
    avoid blocking the event loop, returning its result unchanged.
    """
    knowledge_bases = await asyncio.to_thread(self.list_resources, query)
    return knowledge_bases
def parse_uri(uri: str) -> tuple[str, str]: def parse_uri(uri: str) -> tuple[str, str]:
parsed = urlparse(uri) parsed = urlparse(uri)

View File

@@ -47,7 +47,15 @@ class RetrieverTool(BaseTool):
keywords: str, keywords: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> list[Document]: ) -> list[Document]:
return self._run(keywords, run_manager.get_sync()) logger.info(
f"Retriever tool query: {keywords}", extra={"resources": self.resources}
)
documents = await self.retriever.query_relevant_documents_async(
keywords, self.resources
)
if not documents:
return "No results found from the local knowledge base."
return [doc.to_dict() for doc in documents]
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None: def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:

View File

@@ -56,18 +56,59 @@ def test_retriever_abstract_methods():
def list_resources(self, query=None): def list_resources(self, query=None):
return [Resource(uri="uri", title="title")] return [Resource(uri="uri", title="title")]
async def list_resources_async(self, query=None):
return [Resource(uri="uri", title="title")]
def query_relevant_documents(self, query, resources=[]): def query_relevant_documents(self, query, resources=[]):
return [Document(id="id", chunks=[])] return [Document(id="id", chunks=[])]
async def query_relevant_documents_async(self, query, resources=[]):
return [Document(id="id", chunks=[])]
retriever = DummyRetriever() retriever = DummyRetriever()
# Test synchronous methods
resources = retriever.list_resources() resources = retriever.list_resources()
assert isinstance(resources, list) assert isinstance(resources, list)
assert isinstance(resources[0], Resource) assert isinstance(resources[0], Resource)
assert resources[0].uri == "uri"
docs = retriever.query_relevant_documents("query", resources) docs = retriever.query_relevant_documents("query", resources)
assert isinstance(docs, list) assert isinstance(docs, list)
assert isinstance(docs[0], Document) assert isinstance(docs[0], Document)
assert docs[0].id == "id"
def test_retriever_cannot_instantiate():
    """Instantiating the abstract Retriever directly must raise TypeError."""
    abstract_cls = Retriever
    with pytest.raises(TypeError):
        abstract_cls()
@pytest.mark.asyncio
async def test_retriever_async_methods():
    """Check both async Retriever methods through a minimal concrete subclass."""

    class DummyRetriever(Retriever):
        # The sync and async variants return distinguishable values so the
        # assertions below show that the async methods were the ones awaited.
        def list_resources(self, query=None):
            return [Resource(uri="uri", title="title")]

        async def list_resources_async(self, query=None):
            return [Resource(uri="uri_async", title="title_async")]

        def query_relevant_documents(self, query, resources=[]):
            return [Document(id="id", chunks=[])]

        async def query_relevant_documents_async(self, query, resources=[]):
            return [Document(id="id_async", chunks=[])]

    dummy = DummyRetriever()

    # Async resource listing.
    listed = await dummy.list_resources_async()
    assert isinstance(listed, list)
    first_resource = listed[0]
    assert isinstance(first_resource, Resource)
    assert first_resource.uri == "uri_async"

    # Async document query, fed with the resources obtained above.
    found = await dummy.query_relevant_documents_async("query", listed)
    assert isinstance(found, list)
    first_document = found[0]
    assert isinstance(first_document, Document)
    assert first_document.id == "id_async"

View File

@@ -66,18 +66,20 @@ async def test_retriever_tool_arun():
mock_retriever = Mock(spec=Retriever) mock_retriever = Mock(spec=Retriever)
chunk = Chunk(content="async content", similarity=0.8) chunk = Chunk(content="async content", similarity=0.8)
doc = Document(id="doc2", chunks=[chunk]) doc = Document(id="doc2", chunks=[chunk])
mock_retriever.query_relevant_documents.return_value = [doc]
# Mock the async method
async def mock_async_query(*args, **kwargs):
return [doc]
mock_retriever.query_relevant_documents_async = mock_async_query
resources = [Resource(uri="test://uri", title="Test")] resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources) tool = RetrieverTool(retriever=mock_retriever, resources=resources)
mock_run_manager = Mock(spec=AsyncCallbackManagerForToolRun) mock_run_manager = Mock(spec=AsyncCallbackManagerForToolRun)
mock_sync_manager = Mock(spec=CallbackManagerForToolRun)
mock_run_manager.get_sync.return_value = mock_sync_manager
result = await tool._arun("async keywords", mock_run_manager) result = await tool._arun("async keywords", mock_run_manager)
mock_run_manager.get_sync.assert_called_once()
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 1 assert len(result) == 1
assert result[0] == doc.to_dict() assert result[0] == doc.to_dict()