mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
refactor: Refactors the retriever function to use async/await (#821)
* refactor: Refactors the retriever function to use async/await
This commit is contained in:
@@ -56,18 +56,59 @@ def test_retriever_abstract_methods():
|
||||
def list_resources(self, query=None):
|
||||
return [Resource(uri="uri", title="title")]
|
||||
|
||||
async def list_resources_async(self, query=None):
|
||||
return [Resource(uri="uri", title="title")]
|
||||
|
||||
def query_relevant_documents(self, query, resources=[]):
|
||||
return [Document(id="id", chunks=[])]
|
||||
|
||||
async def query_relevant_documents_async(self, query, resources=[]):
|
||||
return [Document(id="id", chunks=[])]
|
||||
|
||||
retriever = DummyRetriever()
|
||||
# Test synchronous methods
|
||||
resources = retriever.list_resources()
|
||||
assert isinstance(resources, list)
|
||||
assert isinstance(resources[0], Resource)
|
||||
assert resources[0].uri == "uri"
|
||||
|
||||
docs = retriever.query_relevant_documents("query", resources)
|
||||
assert isinstance(docs, list)
|
||||
assert isinstance(docs[0], Document)
|
||||
assert docs[0].id == "id"
|
||||
|
||||
|
||||
def test_retriever_cannot_instantiate():
|
||||
with pytest.raises(TypeError):
|
||||
Retriever()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retriever_async_methods():
|
||||
"""Test that async methods work correctly in DummyRetriever."""
|
||||
class DummyRetriever(Retriever):
|
||||
def list_resources(self, query=None):
|
||||
return [Resource(uri="uri", title="title")]
|
||||
|
||||
async def list_resources_async(self, query=None):
|
||||
return [Resource(uri="uri_async", title="title_async")]
|
||||
|
||||
def query_relevant_documents(self, query, resources=[]):
|
||||
return [Document(id="id", chunks=[])]
|
||||
|
||||
async def query_relevant_documents_async(self, query, resources=[]):
|
||||
return [Document(id="id_async", chunks=[])]
|
||||
|
||||
retriever = DummyRetriever()
|
||||
|
||||
# Test async list_resources
|
||||
resources = await retriever.list_resources_async()
|
||||
assert isinstance(resources, list)
|
||||
assert isinstance(resources[0], Resource)
|
||||
assert resources[0].uri == "uri_async"
|
||||
|
||||
# Test async query_relevant_documents
|
||||
docs = await retriever.query_relevant_documents_async("query", resources)
|
||||
assert isinstance(docs, list)
|
||||
assert isinstance(docs[0], Document)
|
||||
assert docs[0].id == "id_async"
|
||||
|
||||
@@ -66,18 +66,20 @@ async def test_retriever_tool_arun():
|
||||
mock_retriever = Mock(spec=Retriever)
|
||||
chunk = Chunk(content="async content", similarity=0.8)
|
||||
doc = Document(id="doc2", chunks=[chunk])
|
||||
mock_retriever.query_relevant_documents.return_value = [doc]
|
||||
|
||||
# Mock the async method
|
||||
async def mock_async_query(*args, **kwargs):
|
||||
return [doc]
|
||||
|
||||
mock_retriever.query_relevant_documents_async = mock_async_query
|
||||
|
||||
resources = [Resource(uri="test://uri", title="Test")]
|
||||
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
|
||||
|
||||
mock_run_manager = Mock(spec=AsyncCallbackManagerForToolRun)
|
||||
mock_sync_manager = Mock(spec=CallbackManagerForToolRun)
|
||||
mock_run_manager.get_sync.return_value = mock_sync_manager
|
||||
|
||||
result = await tool._arun("async keywords", mock_run_manager)
|
||||
|
||||
mock_run_manager.get_sync.assert_called_once()
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert result[0] == doc.to_dict()
|
||||
|
||||
Reference in New Issue
Block a user