mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-26 15:24:48 +08:00
refactor: Refactors the retriever function to use async/await (#821)
* refactor: Refactors the retriever function to use async/await
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -94,6 +95,17 @@ class DifyProvider(Retriever):
|
||||
|
||||
return list(all_documents.values())
|
||||
|
||||
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text, forwarded unchanged to the synchronous method.
        resources: Resources to search. ``None`` (the default) is forwarded
            as a fresh empty list.

    Returns:
        Whatever the synchronous query_relevant_documents returns.
    """
    # A literal ``[]`` default is a single list object shared by every call;
    # building a new list per call keeps worker threads from sharing it.
    if resources is None:
        resources = []
    return await asyncio.to_thread(
        self.query_relevant_documents, query, resources
    )
|
||||
|
||||
def list_resources(self, query: str | None = None) -> list[Resource]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
@@ -124,6 +136,13 @@ class DifyProvider(Retriever):
|
||||
|
||||
return resources
|
||||
|
||||
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Asynchronous counterpart of list_resources.

    The synchronous list_resources call is executed through
    asyncio.to_thread() so the event loop is not blocked while it runs.
    """
    sync_listing = self.list_resources
    listed = await asyncio.to_thread(sync_listing, query)
    return listed
|
||||
|
||||
|
||||
def parse_uri(uri: str) -> tuple[str, str]:
|
||||
parsed = urlparse(uri)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
@@ -514,6 +515,13 @@ class MilvusRetriever(Retriever):
|
||||
return self._list_local_markdown_resources()
|
||||
return resources
|
||||
|
||||
async def list_resources_async(self, query: Optional[str] = None) -> List[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Runs the synchronous list_resources via asyncio.to_thread() so the
    event loop is not blocked while it executes.
    """
    pending = asyncio.to_thread(self.list_resources, query)
    return await pending
|
||||
|
||||
def _list_local_markdown_resources(self) -> List[Resource]:
|
||||
"""Return local example markdown files as ``Resource`` objects.
|
||||
|
||||
@@ -661,6 +669,17 @@ class MilvusRetriever(Retriever):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to query documents from Milvus: {str(e)}")
|
||||
|
||||
async def query_relevant_documents_async(
    self, query: str, resources: Optional[List[Resource]] = None
) -> List[Document]:
    """
    Asynchronous counterpart of query_relevant_documents.

    Both arguments are forwarded unchanged to the synchronous
    query_relevant_documents, which is run via asyncio.to_thread() so the
    event loop is not blocked.
    """
    sync_query = self.query_relevant_documents
    documents = await asyncio.to_thread(sync_query, query, resources)
    return documents
|
||||
|
||||
def create_collection(self) -> None:
|
||||
"""Public hook ensuring collection exists (explicit initialization)."""
|
||||
if not self.client:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -108,6 +109,17 @@ class MOIProvider(Retriever):
|
||||
|
||||
return list(docs.values())
|
||||
|
||||
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text, forwarded unchanged to the synchronous method.
        resources: Resources to search. ``None`` (the default) is forwarded
            as a fresh empty list.

    Returns:
        Whatever the synchronous query_relevant_documents returns.
    """
    # Avoid a mutable ``[]`` default: that single list object would be
    # shared by every call and passed into worker threads.
    if resources is None:
        resources = []
    return await asyncio.to_thread(
        self.query_relevant_documents, query, resources
    )
|
||||
|
||||
def list_resources(self, query: str | None = None) -> list[Resource]:
|
||||
"""
|
||||
List resources from MOI API with optional query filtering and limit support.
|
||||
@@ -144,6 +156,13 @@ class MOIProvider(Retriever):
|
||||
|
||||
return resources
|
||||
|
||||
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Delegates to the synchronous list_resources through
    asyncio.to_thread() so the event loop is not blocked.
    """
    resources = await asyncio.to_thread(self.list_resources, query)
    return resources
|
||||
|
||||
def _parse_uri(self, uri: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse URI to extract dataset ID and document ID.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
@@ -351,6 +352,13 @@ class QdrantProvider(Retriever):
|
||||
return self._list_local_markdown_resources()
|
||||
return resources
|
||||
|
||||
async def list_resources_async(self, query: Optional[str] = None) -> List[Resource]:
    """
    Asynchronous counterpart of list_resources.

    The synchronous list_resources is run via asyncio.to_thread() so the
    event loop is not blocked while it executes.
    """
    blocking_call = self.list_resources
    return await asyncio.to_thread(blocking_call, query)
|
||||
|
||||
def _list_local_markdown_resources(self) -> List[Resource]:
|
||||
current_file = Path(__file__)
|
||||
project_root = current_file.parent.parent.parent
|
||||
@@ -419,6 +427,17 @@ class QdrantProvider(Retriever):
|
||||
|
||||
return list(documents.values())
|
||||
|
||||
async def query_relevant_documents_async(
    self, query: str, resources: Optional[List[Resource]] = None
) -> List[Document]:
    """
    Asynchronous counterpart of query_relevant_documents.

    Forwards ``query`` and ``resources`` unchanged to the synchronous
    query_relevant_documents, executed via asyncio.to_thread() so the
    event loop is not blocked.
    """
    pending = asyncio.to_thread(self.query_relevant_documents, query, resources)
    return await pending
|
||||
|
||||
def create_collection(self) -> None:
|
||||
if not self.client:
|
||||
self._connect()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
@@ -98,6 +99,17 @@ class RAGFlowProvider(Retriever):
|
||||
|
||||
return list(docs.values())
|
||||
|
||||
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text, forwarded unchanged to the synchronous method.
        resources: Resources to search. ``None`` (the default) is forwarded
            as a fresh empty list.

    Returns:
        Whatever the synchronous query_relevant_documents returns.
    """
    # ``None`` sentinel instead of ``[]``: a list default is created once and
    # would be the same object for every call that omits ``resources``.
    if resources is None:
        resources = []
    return await asyncio.to_thread(
        self.query_relevant_documents, query, resources
    )
|
||||
|
||||
def list_resources(self, query: str | None = None) -> list[Resource]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
@@ -128,6 +140,13 @@ class RAGFlowProvider(Retriever):
|
||||
|
||||
return resources
|
||||
|
||||
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Executes the synchronous list_resources through asyncio.to_thread()
    so the event loop is not blocked.
    """
    listing_task = asyncio.to_thread(self.list_resources, query)
    found = await listing_task
    return found
|
||||
|
||||
|
||||
def parse_uri(uri: str) -> tuple[str, str]:
|
||||
parsed = urlparse(uri)
|
||||
|
||||
@@ -67,7 +67,18 @@ class Retriever(abc.ABC):
|
||||
@abc.abstractmethod
def list_resources(self, query: str | None = None) -> list[Resource]:
    """
    List resources from the rag provider (synchronous version).

    Abstract: concrete retrievers supply the implementation.
    """
|
||||
|
||||
@abc.abstractmethod
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    List resources from the rag provider (asynchronous version).

    Implementations pick one of two approaches:
    - native async I/O, for true non-blocking behavior; or
    - asyncio.to_thread() around the synchronous version, when async I/O
      is not available.
    """
    ...
|
||||
|
||||
@@ -76,7 +87,20 @@ class Retriever(abc.ABC):
|
||||
self, query: str, resources: list[Resource] = []
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Query relevant documents from the resources.
|
||||
Query relevant documents from the resources (synchronous version).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] = []
) -> list[Document]:
    """
    Query relevant documents from the resources (asynchronous version).

    Implementations pick one of two approaches:
    - native async I/O, for true non-blocking behavior; or
    - asyncio.to_thread() around the synchronous version, when async I/O
      is not available.
    """
    ...
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
@@ -255,6 +256,17 @@ class VikingDBKnowledgeBaseProvider(Retriever):
|
||||
|
||||
return list(all_documents.values())
|
||||
|
||||
async def query_relevant_documents_async(
    self, query: str, resources: list[Resource] | None = None
) -> list[Document]:
    """
    Asynchronous version of query_relevant_documents.

    Wraps the synchronous implementation in asyncio.to_thread() to avoid
    blocking the event loop.

    Args:
        query: Query text, forwarded unchanged to the synchronous method.
        resources: Resources to search. ``None`` (the default) is forwarded
            as a fresh empty list.

    Returns:
        Whatever the synchronous query_relevant_documents returns.
    """
    # Replace the shared mutable ``[]`` default with a per-call list so no
    # list object is reused across calls running in different threads.
    if resources is None:
        resources = []
    return await asyncio.to_thread(
        self.query_relevant_documents, query, resources
    )
|
||||
|
||||
def list_resources(self, query: str | None = None) -> list[Resource]:
|
||||
"""
|
||||
List resources (knowledge bases) from the knowledge base service
|
||||
@@ -291,6 +303,13 @@ class VikingDBKnowledgeBaseProvider(Retriever):
|
||||
|
||||
return resources
|
||||
|
||||
async def list_resources_async(self, query: str | None = None) -> list[Resource]:
    """
    Asynchronous counterpart of list_resources.

    Hands the synchronous list_resources to asyncio.to_thread() so the
    event loop is not blocked while it runs.
    """
    sync_method = self.list_resources
    result = await asyncio.to_thread(sync_method, query)
    return result
|
||||
|
||||
|
||||
def parse_uri(uri: str) -> tuple[str, str]:
|
||||
parsed = urlparse(uri)
|
||||
|
||||
Reference in New Issue
Block a user