diff --git a/src/rag/dify.py b/src/rag/dify.py index 527e16e..6e10082 100644 --- a/src/rag/dify.py +++ b/src/rag/dify.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import asyncio import os from urllib.parse import urlparse @@ -94,6 +95,17 @@ class DifyProvider(Retriever): return list(all_documents.values()) + async def query_relevant_documents_async( + self, query: str, resources: list[Resource] = [] + ) -> list[Document]: + """ + Asynchronous version of query_relevant_documents. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread( + self.query_relevant_documents, query, resources + ) + def list_resources(self, query: str | None = None) -> list[Resource]: headers = { "Authorization": f"Bearer {self.api_key}", @@ -124,6 +136,13 @@ class DifyProvider(Retriever): return resources + async def list_resources_async(self, query: str | None = None) -> list[Resource]: + """ + Asynchronous version of list_resources. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread(self.list_resources, query) + def parse_uri(uri: str) -> tuple[str, str]: parsed = urlparse(uri) diff --git a/src/rag/milvus.py b/src/rag/milvus.py index 4c9d86d..0c5d23c 100644 --- a/src/rag/milvus.py +++ b/src/rag/milvus.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import asyncio import hashlib import logging import re @@ -514,6 +515,13 @@ class MilvusRetriever(Retriever): return self._list_local_markdown_resources() return resources + async def list_resources_async(self, query: Optional[str] = None) -> List[Resource]: + """ + Asynchronous version of list_resources. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread(self.list_resources, query) + def _list_local_markdown_resources(self) -> List[Resource]: """Return local example markdown files as ``Resource`` objects. @@ -661,6 +669,17 @@ class MilvusRetriever(Retriever): except Exception as e: raise RuntimeError(f"Failed to query documents from Milvus: {str(e)}") + async def query_relevant_documents_async( + self, query: str, resources: Optional[List[Resource]] = None + ) -> List[Document]: + """ + Asynchronous version of query_relevant_documents. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread( + self.query_relevant_documents, query, resources + ) + def create_collection(self) -> None: """Public hook ensuring collection exists (explicit initialization).""" if not self.client: diff --git a/src/rag/moi.py b/src/rag/moi.py index 7f05a84..c3a0976 100644 --- a/src/rag/moi.py +++ b/src/rag/moi.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import asyncio import os from urllib.parse import urlparse @@ -108,6 +109,17 @@ class MOIProvider(Retriever): return list(docs.values()) + async def query_relevant_documents_async( + self, query: str, resources: list[Resource] = [] + ) -> list[Document]: + """ + Asynchronous version of query_relevant_documents. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread( + self.query_relevant_documents, query, resources + ) + def list_resources(self, query: str | None = None) -> list[Resource]: """ List resources from MOI API with optional query filtering and limit support. @@ -144,6 +156,13 @@ class MOIProvider(Retriever): return resources + async def list_resources_async(self, query: str | None = None) -> list[Resource]: + """ + Asynchronous version of list_resources. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread(self.list_resources, query) + def _parse_uri(self, uri: str) -> tuple[str, str]: """ Parse URI to extract dataset ID and document ID. diff --git a/src/rag/qdrant.py b/src/rag/qdrant.py index fd01741..6d71b9b 100644 --- a/src/rag/qdrant.py +++ b/src/rag/qdrant.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import asyncio import hashlib import logging import uuid @@ -351,6 +352,13 @@ class QdrantProvider(Retriever): return self._list_local_markdown_resources() return resources + async def list_resources_async(self, query: Optional[str] = None) -> List[Resource]: + """ + Asynchronous version of list_resources. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread(self.list_resources, query) + def _list_local_markdown_resources(self) -> List[Resource]: current_file = Path(__file__) project_root = current_file.parent.parent.parent @@ -419,6 +427,17 @@ class QdrantProvider(Retriever): return list(documents.values()) + async def query_relevant_documents_async( + self, query: str, resources: Optional[List[Resource]] = None + ) -> List[Document]: + """ + Asynchronous version of query_relevant_documents. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread( + self.query_relevant_documents, query, resources + ) + def create_collection(self) -> None: if not self.client: self._connect() diff --git a/src/rag/ragflow.py b/src/rag/ragflow.py index ad8b81c..eb87dce 100644 --- a/src/rag/ragflow.py +++ b/src/rag/ragflow.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import asyncio import os from typing import List, Optional from urllib.parse import urlparse @@ -98,6 +99,17 @@ class RAGFlowProvider(Retriever): return list(docs.values()) + async def query_relevant_documents_async( + self, query: str, resources: list[Resource] = [] + ) -> list[Document]: + """ + Asynchronous version of query_relevant_documents. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread( + self.query_relevant_documents, query, resources + ) + def list_resources(self, query: str | None = None) -> list[Resource]: headers = { "Authorization": f"Bearer {self.api_key}", @@ -128,6 +140,13 @@ class RAGFlowProvider(Retriever): return resources + async def list_resources_async(self, query: str | None = None) -> list[Resource]: + """ + Asynchronous version of list_resources. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread(self.list_resources, query) + def parse_uri(uri: str) -> tuple[str, str]: parsed = urlparse(uri) diff --git a/src/rag/retriever.py b/src/rag/retriever.py index 799b983..f32ffc8 100644 --- a/src/rag/retriever.py +++ b/src/rag/retriever.py @@ -67,7 +67,18 @@ class Retriever(abc.ABC): @abc.abstractmethod def list_resources(self, query: str | None = None) -> list[Resource]: """ - List resources from the rag provider. + List resources from the rag provider (synchronous version). + """ + pass + + @abc.abstractmethod + async def list_resources_async(self, query: str | None = None) -> list[Resource]: + """ + List resources from the rag provider (asynchronous version). + + Implementations should choose between: + - Providing native async I/O operations for true non-blocking behavior + - Using asyncio.to_thread() to wrap the synchronous version if async I/O is not available """ pass @@ -76,7 +87,20 @@ class Retriever(abc.ABC): self, query: str, resources: list[Resource] = [] ) -> list[Document]: """ - Query relevant documents from the resources. + Query relevant documents from the resources (synchronous version). + """ + pass + + @abc.abstractmethod + async def query_relevant_documents_async( + self, query: str, resources: list[Resource] = [] + ) -> list[Document]: + """ + Query relevant documents from the resources (asynchronous version). + + Implementations should choose between: + - Providing native async I/O operations for true non-blocking behavior + - Using asyncio.to_thread() to wrap the synchronous version if async I/O is not available """ pass diff --git a/src/rag/vikingdb_knowledge_base.py b/src/rag/vikingdb_knowledge_base.py index bbab88b..ccd1466 100644 --- a/src/rag/vikingdb_knowledge_base.py +++ b/src/rag/vikingdb_knowledge_base.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import asyncio import hashlib import hmac import json @@ -255,6 +256,17 @@ class VikingDBKnowledgeBaseProvider(Retriever): return list(all_documents.values()) + async def query_relevant_documents_async( + self, query: str, resources: list[Resource] = [] + ) -> list[Document]: + """ + Asynchronous version of query_relevant_documents. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread( + self.query_relevant_documents, query, resources + ) + def list_resources(self, query: str | None = None) -> list[Resource]: """ List resources (knowledge bases) from the knowledge base service @@ -291,6 +303,13 @@ class VikingDBKnowledgeBaseProvider(Retriever): return resources + async def list_resources_async(self, query: str | None = None) -> list[Resource]: + """ + Asynchronous version of list_resources. + Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop. + """ + return await asyncio.to_thread(self.list_resources, query) + def parse_uri(uri: str) -> tuple[str, str]: parsed = urlparse(uri) diff --git a/src/tools/retriever.py b/src/tools/retriever.py index 091185e..476adb3 100644 --- a/src/tools/retriever.py +++ b/src/tools/retriever.py @@ -47,7 +47,15 @@ class RetrieverTool(BaseTool): keywords: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> list[Document]: - return self._run(keywords, run_manager.get_sync()) + logger.info( + f"Retriever tool query: {keywords}", extra={"resources": self.resources} + ) + documents = await self.retriever.query_relevant_documents_async( + keywords, self.resources + ) + if not documents: + return "No results found from the local knowledge base." + return [doc.to_dict() for doc in documents] def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None: diff --git a/tests/unit/rag/test_retriever.py b/tests/unit/rag/test_retriever.py index 582f2af..1286e96 100644 --- a/tests/unit/rag/test_retriever.py +++ b/tests/unit/rag/test_retriever.py @@ -56,18 +56,59 @@ def test_retriever_abstract_methods(): def list_resources(self, query=None): return [Resource(uri="uri", title="title")] + async def list_resources_async(self, query=None): + return [Resource(uri="uri", title="title")] + def query_relevant_documents(self, query, resources=[]): return [Document(id="id", chunks=[])] + async def query_relevant_documents_async(self, query, resources=[]): + return [Document(id="id", chunks=[])] + retriever = DummyRetriever() + # Test synchronous methods resources = retriever.list_resources() assert isinstance(resources, list) assert isinstance(resources[0], Resource) + assert resources[0].uri == "uri" + docs = retriever.query_relevant_documents("query", resources) assert isinstance(docs, list) assert isinstance(docs[0], Document) + assert docs[0].id == "id" def test_retriever_cannot_instantiate(): with pytest.raises(TypeError): Retriever() + + +@pytest.mark.asyncio +async def test_retriever_async_methods(): + """Test that async methods work correctly in DummyRetriever.""" + class DummyRetriever(Retriever): + def list_resources(self, query=None): + return [Resource(uri="uri", title="title")] + + async def list_resources_async(self, query=None): + return [Resource(uri="uri_async", title="title_async")] + + def query_relevant_documents(self, query, resources=[]): + return [Document(id="id", chunks=[])] + + async def query_relevant_documents_async(self, query, resources=[]): + return [Document(id="id_async", chunks=[])] + + retriever = DummyRetriever() + + # Test async list_resources + resources = await retriever.list_resources_async() + assert isinstance(resources, list) + assert isinstance(resources[0], Resource) + assert resources[0].uri == "uri_async" + + # Test async query_relevant_documents + docs = await retriever.query_relevant_documents_async("query", resources) + assert isinstance(docs, list) + assert isinstance(docs[0], Document) + assert docs[0].id == "id_async" diff --git a/tests/unit/tools/test_tools_retriever.py b/tests/unit/tools/test_tools_retriever.py index e4aaee0..18f1b15 100644 --- a/tests/unit/tools/test_tools_retriever.py +++ b/tests/unit/tools/test_tools_retriever.py @@ -66,18 +66,20 @@ async def test_retriever_tool_arun(): mock_retriever = Mock(spec=Retriever) chunk = Chunk(content="async content", similarity=0.8) doc = Document(id="doc2", chunks=[chunk]) - mock_retriever.query_relevant_documents.return_value = [doc] + + # Mock the async method + async def mock_async_query(*args, **kwargs): + return [doc] + + mock_retriever.query_relevant_documents_async = mock_async_query resources = [Resource(uri="test://uri", title="Test")] tool = RetrieverTool(retriever=mock_retriever, resources=resources) mock_run_manager = Mock(spec=AsyncCallbackManagerForToolRun) - mock_sync_manager = Mock(spec=CallbackManagerForToolRun) - mock_run_manager.get_sync.return_value = mock_sync_manager result = await tool._arun("async keywords", mock_run_manager) - mock_run_manager.get_sync.assert_called_once() assert isinstance(result, list) assert len(result) == 1 assert result[0] == doc.to_dict()