refactor: Refactors the retriever function to use async/await (#821)

* refactor: Refactors the retriever function to use async/await
This commit is contained in:
Xun
2026-01-20 19:56:26 +08:00
committed by GitHub
parent 2ed0eeb107
commit 0e64c52975
10 changed files with 196 additions and 7 deletions

View File

@@ -56,18 +56,59 @@ def test_retriever_abstract_methods():
def list_resources(self, query=None):
return [Resource(uri="uri", title="title")]
async def list_resources_async(self, query=None):
return [Resource(uri="uri", title="title")]
def query_relevant_documents(self, query, resources=[]):
return [Document(id="id", chunks=[])]
async def query_relevant_documents_async(self, query, resources=[]):
return [Document(id="id", chunks=[])]
retriever = DummyRetriever()
# Test synchronous methods
resources = retriever.list_resources()
assert isinstance(resources, list)
assert isinstance(resources[0], Resource)
assert resources[0].uri == "uri"
docs = retriever.query_relevant_documents("query", resources)
assert isinstance(docs, list)
assert isinstance(docs[0], Document)
assert docs[0].id == "id"
def test_retriever_cannot_instantiate():
with pytest.raises(TypeError):
Retriever()
@pytest.mark.asyncio
async def test_retriever_async_methods():
"""Test that async methods work correctly in DummyRetriever."""
class DummyRetriever(Retriever):
def list_resources(self, query=None):
return [Resource(uri="uri", title="title")]
async def list_resources_async(self, query=None):
return [Resource(uri="uri_async", title="title_async")]
def query_relevant_documents(self, query, resources=[]):
return [Document(id="id", chunks=[])]
async def query_relevant_documents_async(self, query, resources=[]):
return [Document(id="id_async", chunks=[])]
retriever = DummyRetriever()
# Test async list_resources
resources = await retriever.list_resources_async()
assert isinstance(resources, list)
assert isinstance(resources[0], Resource)
assert resources[0].uri == "uri_async"
# Test async query_relevant_documents
docs = await retriever.query_relevant_documents_async("query", resources)
assert isinstance(docs, list)
assert isinstance(docs[0], Document)
assert docs[0].id == "id_async"