mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
* fix: ensure researcher agent uses web search tool instead of generating URLs (#702)
  - Add enforce_researcher_search configuration option (default: True) to control web search requirement
  - Strengthen researcher prompts in both English and Chinese with explicit instructions to use web_search tool
  - Implement validate_web_search_usage function to detect if web search tool was used during research
  - Add validation logic that warns when researcher doesn't use web search tool
  - Enhance logging for web search tools with special markers for easy tracking
  - Skip validation during unit tests to avoid test failures
  - Update _execute_agent_step to accept config parameter for proper configuration access
  This addresses issue #702 where the researcher agent was generating URLs on its own instead of using the web search tool.
* fix: addressed the code review comment
* fix the unit test error and update the code
505 lines
17 KiB
Python
505 lines
17 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import hashlib
|
|
import logging
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Sequence, Set
|
|
|
|
from langchain_openai import OpenAIEmbeddings
|
|
from langchain_qdrant import QdrantVectorStore
|
|
from openai import OpenAI
|
|
from qdrant_client import QdrantClient, grpc
|
|
from qdrant_client.models import (
|
|
Distance,
|
|
FieldCondition,
|
|
Filter,
|
|
MatchValue,
|
|
PointStruct,
|
|
VectorParams,
|
|
)
|
|
|
|
from src.config.loader import get_bool_env, get_int_env, get_str_env
|
|
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
|
|
|
logger = logging.getLogger(__name__)

# Number of points requested per ``client.scroll`` call when paging through
# the collection.
SCROLL_SIZE = 64
|
|
|
|
|
|
class DashscopeEmbeddings:
    """Embedding client for an OpenAI-compatible embeddings endpoint.

    Provides ``embed_query`` and ``embed_documents`` on top of the ``openai``
    client's ``embeddings.create`` call. Unknown keyword arguments are ignored.

    Keyword Args:
        api_key: API key handed to the ``OpenAI`` client (default ``""``).
        base_url: Endpoint base URL handed to the ``OpenAI`` client
            (default ``""``).
        model: Embedding model name (default ``""``).
        encoding_format: Encoding format requested from the API
            (default ``"float"``).
        batch_size: Maximum number of inputs sent per request (default 10,
            values below 1 are clamped to 1).
    """

    # Inputs per request. NOTE(review): DashScope documents a per-request cap
    # on input rows (10 for text-embedding-v3/v4 as far as I know) — confirm
    # against the current provider documentation.
    _batch_size: int = 10

    def __init__(self, **kwargs: Any) -> None:
        self._client: OpenAI = OpenAI(
            api_key=kwargs.get("api_key", ""), base_url=kwargs.get("base_url", "")
        )
        self._model: str = kwargs.get("model", "")
        self._encoding_format: str = kwargs.get("encoding_format", "float")
        batch_size = kwargs.get("batch_size")
        if batch_size is not None:
            self._batch_size = max(1, int(batch_size))

    def _embed(self, texts: Sequence[str]) -> List[List[float]]:
        """Embed ``texts`` and return one vector per input, in input order.

        Non-string items are converted with ``str``. Inputs are sent in
        batches of at most ``_batch_size`` so that large lists are split over
        several requests; an empty input makes no request and returns ``[]``.
        """
        clean_texts = [t if isinstance(t, str) else str(t) for t in texts]
        vectors: List[List[float]] = []
        for start in range(0, len(clean_texts), self._batch_size):
            resp = self._client.embeddings.create(
                model=self._model,
                input=clean_texts[start : start + self._batch_size],
                encoding_format=self._encoding_format,
            )
            vectors.extend(d.embedding for d in resp.data)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single text (``[]`` if none came back)."""
        embeddings = self._embed([text])
        return embeddings[0] if embeddings else []

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return the embeddings of ``texts``, one vector per input."""
        return self._embed(texts)
|
|
|
|
|
|
class QdrantProvider(Retriever):
    """Retriever implementation that stores and searches documents in Qdrant.

    All settings come from ``QDRANT_*`` environment variables read in
    ``__init__``. The Qdrant client is created lazily: public methods call
    ``_connect`` on first use while ``self.client`` is unset.
    """

    def __init__(self) -> None:
        """Read configuration from the environment and build the embedding model.

        Raises:
            ValueError: If ``QDRANT_EMBEDDING_PROVIDER`` names an unsupported
                provider.
        """
        # Assigned first so close()/__del__ see them even if a later step raises.
        self.client: Any = None
        self.vector_store: Any = None

        self.location: str = get_str_env("QDRANT_LOCATION", ":memory:")
        self.api_key: str = get_str_env("QDRANT_API_KEY", "")
        self.collection_name: str = get_str_env("QDRANT_COLLECTION", "documents")

        # Anything that is not a plain digit string falls back to 10.
        top_k_raw = get_str_env("QDRANT_TOP_K", "10")
        self.top_k: int = int(top_k_raw) if top_k_raw.isdigit() else 10

        self.embedding_model_name = get_str_env("QDRANT_EMBEDDING_MODEL")
        self.embedding_api_key = get_str_env("QDRANT_EMBEDDING_API_KEY")
        self.embedding_base_url = get_str_env("QDRANT_EMBEDDING_BASE_URL")
        self.embedding_dim: int = self._get_embedding_dimension(
            self.embedding_model_name
        )
        self.embedding_provider = get_str_env("QDRANT_EMBEDDING_PROVIDER", "openai")

        self.auto_load_examples: bool = get_bool_env("QDRANT_AUTO_LOAD_EXAMPLES", True)
        self.examples_dir: str = get_str_env("QDRANT_EXAMPLES_DIR", "examples")
        self.chunk_size: int = get_int_env("QDRANT_CHUNK_SIZE", 4000)

        self._init_embedding_model()

    def _init_embedding_model(self) -> None:
        """Create ``self.embedding_model`` for the configured provider.

        Raises:
            ValueError: If the provider is neither ``openai`` nor ``dashscope``
                (compared case-insensitively).
        """
        kwargs = {
            "api_key": self.embedding_api_key,
            "model": self.embedding_model_name,
            "base_url": self.embedding_base_url,
            "encoding_format": "float",
            "dimensions": self.embedding_dim,
        }
        provider = self.embedding_provider.lower()
        if provider == "openai":
            self.embedding_model = OpenAIEmbeddings(**kwargs)
        elif provider == "dashscope":
            self.embedding_model = DashscopeEmbeddings(**kwargs)
        else:
            raise ValueError(
                f"Unsupported embedding provider: {self.embedding_provider}. "
                "Supported providers: openai, dashscope"
            )

    def _get_embedding_dimension(self, model_name: str) -> int:
        """Return the vector size to use for ``model_name``.

        A positive ``QDRANT_EMBEDDING_DIM`` wins; otherwise the built-in table
        is consulted, with 1536 as the fallback for unknown models.
        """
        embedding_dims = {
            "text-embedding-ada-002": 1536,
            "text-embedding-v4": 2048,
        }

        explicit_dim = get_int_env("QDRANT_EMBEDDING_DIM", 0)
        if explicit_dim > 0:
            return explicit_dim
        return embedding_dims.get(model_name, 1536)

    def _examples_path(self) -> Path:
        """Return ``examples_dir`` joined onto the directory three levels above this file."""
        return Path(__file__).parent.parent.parent / self.examples_dir

    def _ensure_collection_exists(self) -> None:
        """Create the collection (cosine distance, ``embedding_dim`` wide) if missing."""
        if self.client.collection_exists(self.collection_name):
            return
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.embedding_dim, distance=Distance.COSINE
            ),
        )
        logger.info("Created Qdrant collection: %s", self.collection_name)

    def _load_example_files(self) -> None:
        """Embed and upsert every ``*.md`` file of the examples directory.

        Files whose generated id is already present in the collection are
        skipped. A failure on one file is logged and does not stop the others.
        """
        examples_path = self._examples_path()

        if not examples_path.exists():
            logger.info("Examples directory not found: %s", examples_path)
            return

        logger.info("Loading example files from: %s", examples_path)

        md_files = list(examples_path.glob("*.md"))
        if not md_files:
            logger.info("No markdown files found in examples directory")
            return

        existing_docs = self._get_existing_document_ids()
        loaded_count = 0
        for md_file in md_files:
            doc_id = self._generate_doc_id(md_file)

            # Multi-chunk files are stored under "<doc_id>_chunk_<i>", so the
            # bare doc_id is never in the collection for them; checking the
            # first chunk as well avoids re-embedding them on every load.
            if doc_id in existing_docs or f"{doc_id}_chunk_0" in existing_docs:
                continue

            try:
                content = md_file.read_text(encoding="utf-8")
                title = self._extract_title_from_markdown(content, md_file.name)

                chunks = self._split_content(content)

                for i, chunk in enumerate(chunks):
                    chunk_id = f"{doc_id}_chunk_{i}" if len(chunks) > 1 else doc_id
                    self._insert_document_chunk(
                        doc_id=chunk_id,
                        content=chunk,
                        title=title,
                        url=f"qdrant://{self.collection_name}/{md_file.name}",
                        metadata={"source": "examples", "file": md_file.name},
                    )

                loaded_count += 1
                logger.debug("Loaded example markdown: %s", md_file.name)

            except Exception as e:
                logger.warning("Error loading %s: %s", md_file.name, e)

        logger.info("Successfully loaded %d example files into Qdrant", loaded_count)

    def _generate_doc_id(self, file_path: Path) -> str:
        """Build an id from the file stem plus a short hash of name, size and mtime.

        The id therefore changes whenever the file's size or modification time
        changes.
        """
        file_stat = file_path.stat()
        content_hash = hashlib.md5(
            f"{file_path.name}_{file_stat.st_size}_{file_stat.st_mtime}".encode()
        ).hexdigest()[:8]
        return f"example_{file_path.stem}_{content_hash}"

    def _extract_title_from_markdown(self, content: str, filename: str) -> str:
        """Return the first ``# `` heading of ``content``, else a title made from ``filename``."""
        for line in content.split("\n"):
            stripped = line.strip()
            if stripped.startswith("# "):
                return stripped[2:].strip()

        # Drop only the trailing extension; a blanket replace would also eat
        # ".md" occurring in the middle of the name.
        stem = filename[:-3] if filename.endswith(".md") else filename
        return stem.replace("_", " ").title()

    def _split_content(self, content: str) -> List[str]:
        """Split ``content`` into chunks of roughly ``chunk_size`` characters.

        Content that already fits is returned as a single chunk. Otherwise
        blank-line-separated paragraphs are packed greedily. Paragraphs are
        never split, so one paragraph longer than ``chunk_size`` becomes an
        oversized chunk of its own.
        """
        if len(content) <= self.chunk_size:
            return [content]

        chunks = []
        current_chunk = ""

        for paragraph in content.split("\n\n"):
            if len(current_chunk) + len(paragraph) <= self.chunk_size:
                current_chunk += paragraph + "\n\n"
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = paragraph + "\n\n"

        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks

    def _string_to_uuid(self, text: str) -> str:
        """Map ``text`` to a deterministic UUIDv5 string (DNS namespace)."""
        return str(uuid.uuid5(uuid.NAMESPACE_DNS, text))

    def _scroll_all_points(
        self,
        scroll_filter: Optional[Filter] = None,
        with_payload: bool = True,
        with_vectors: bool = False,
    ) -> List[Any]:
        """Return every point matching ``scroll_filter``.

        Pages through the collection ``SCROLL_SIZE`` points at a time until
        the returned offset is ``None`` or an empty gRPC ``PointId``.
        """
        results = []
        next_offset = None
        stop_scrolling = False

        while not stop_scrolling:
            points, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=SCROLL_SIZE,
                offset=next_offset,
                with_payload=with_payload,
                with_vectors=with_vectors,
            )
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId)
                and getattr(next_offset, "num", 0) == 0
                and getattr(next_offset, "uuid", "") == ""
            )
            results.extend(points)

        return results

    def _get_existing_document_ids(self) -> Set[str]:
        """Return the ``doc_id`` values stored in the collection.

        Points without a payload are ignored. Any error yields an empty set,
        so callers treat the collection as empty.
        """
        try:
            points = self._scroll_all_points(with_payload=True, with_vectors=False)
            return {
                point.payload.get("doc_id", str(point.id))
                for point in points
                if point.payload
            }
        except Exception:
            logger.debug(
                "Could not read existing document ids; assuming none", exc_info=True
            )
            return set()

    def _insert_document_chunk(
        self, doc_id: str, content: str, title: str, url: str, metadata: Dict[str, Any]
    ) -> None:
        """Embed ``content`` and upsert it as one point.

        The point id is derived from ``doc_id``, so re-inserting the same
        ``doc_id`` overwrites the earlier point. Keys in ``metadata`` are
        merged into the payload last and override the standard keys.
        """
        embedding = self._get_embedding(content)

        payload = {
            "doc_id": doc_id,
            "content": content,
            "title": title,
            "url": url,
            **metadata,
        }

        point = PointStruct(
            id=self._string_to_uuid(doc_id), vector=embedding, payload=payload
        )

        self.client.upsert(
            collection_name=self.collection_name, points=[point], wait=True
        )

    def _connect(self) -> None:
        """Create the Qdrant client, ensure the collection exists, build the vector store.

        A vector-store construction failure is tolerated: ``vector_store`` is
        left as ``None`` and the failure is logged.
        """
        client_kwargs = {"location": self.location}
        if self.api_key:
            client_kwargs["api_key"] = self.api_key
        self.client = QdrantClient(**client_kwargs)

        self._ensure_collection_exists()

        try:
            self.vector_store = QdrantVectorStore(
                client=self.client,
                collection_name=self.collection_name,
                embedding=self.embedding_model,
            )
        except Exception:
            logger.warning(
                "Could not initialize QdrantVectorStore; "
                "resource listing will use scroll instead",
                exc_info=True,
            )
            self.vector_store = None

    def _get_embedding(self, text: str) -> List[float]:
        """Return the embedding of ``text`` with surrounding whitespace removed."""
        return self.embedding_model.embed_query(text=text.strip())

    def list_resources(self, query: Optional[str] = None) -> List[Resource]:
        """List example resources, one per distinct URI.

        With a ``query`` and an available vector store the list comes from a
        similarity search; otherwise from scrolling all points whose
        ``source`` is ``"examples"``. If connecting or querying fails, the
        local markdown files are listed instead.
        """
        if not self.client:
            try:
                self._connect()
            except Exception:
                logger.warning(
                    "Could not connect to Qdrant, falling back to local examples.",
                    exc_info=True,
                )
                return self._list_local_markdown_resources()

        resources: List[Resource] = []
        seen_uris: Set[str] = set()

        def add_resource(uri: str, title: str) -> None:
            # Several chunks of one file share a URI; report the file once.
            if uri in seen_uris:
                return
            seen_uris.add(uri)
            resources.append(
                Resource(uri=uri, title=title, description="Stored Qdrant document")
            )

        try:
            if query and self.vector_store:
                # NOTE(review): the filter is passed as a plain dict and the
                # results are read from ``d.metadata``, whereas this class
                # writes a flat payload in _insert_document_chunk. Confirm with
                # the installed langchain_qdrant that the dict filter is
                # accepted and that "url"/"title" surface in ``metadata``; any
                # exception here ends in the local-file fallback below.
                docs = self.vector_store.similarity_search(
                    query, k=100, filter={"source": "examples"}
                )
                for d in docs:
                    meta = d.metadata or {}
                    uri = meta.get("url", "") or f"qdrant://{meta.get('id', '')}"
                    add_resource(uri, meta.get("title", "") or meta.get("id", "Unnamed"))
            else:
                all_points = self._scroll_all_points(
                    scroll_filter=Filter(
                        must=[
                            FieldCondition(
                                key="source", match=MatchValue(value="examples")
                            )
                        ]
                    ),
                    with_payload=True,
                    with_vectors=False,
                )

                for point in all_points:
                    payload = point.payload or {}
                    doc_id = payload.get("doc_id", str(point.id))
                    uri = payload.get("url", "") or f"qdrant://{doc_id}"
                    add_resource(uri, payload.get("title", "") or doc_id)

            logger.info(
                "Successfully listed %d resources from Qdrant collection: %s",
                len(resources),
                self.collection_name,
            )
        except Exception:
            logger.warning(
                "Failed to query Qdrant for resources, falling back to local examples.",
                exc_info=True,
            )
            return self._list_local_markdown_resources()
        return resources

    def _list_local_markdown_resources(self) -> List[Resource]:
        """List the ``*.md`` files of the examples directory without touching Qdrant.

        Unreadable files are skipped; a missing directory yields ``[]``.
        """
        examples_path = self._examples_path()
        if not examples_path.exists():
            return []

        resources: List[Resource] = []
        for md_file in examples_path.glob("*.md"):
            try:
                content = md_file.read_text(encoding="utf-8", errors="ignore")
                title = self._extract_title_from_markdown(content, md_file.name)
                uri = f"qdrant://{self.collection_name}/{md_file.name}"
                resources.append(
                    Resource(
                        uri=uri,
                        title=title,
                        description="Local markdown example (not yet ingested)",
                    )
                )
            except Exception:
                logger.debug("Skipping unreadable example file: %s", md_file, exc_info=True)
                continue
        return resources

    def query_relevant_documents(
        self, query: str, resources: Optional[List[Resource]] = None
    ) -> List[Document]:
        """Return the documents most similar to ``query``.

        The ``top_k`` nearest points are fetched first and only afterwards
        restricted to ``resources`` (when given), so fewer than ``top_k``
        chunks may come back. A point belongs to a resource when its ``url``
        or ``doc_id`` is a substring of the resource URI. Points sharing a
        ``doc_id`` are grouped into one ``Document``.
        """
        resources = resources or []
        if not self.client:
            self._connect()

        query_embedding = self._get_embedding(query)

        search_results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_embedding,
            limit=self.top_k,
            with_payload=True,
        ).points

        documents: Dict[str, Document] = {}

        for result in search_results:
            payload = result.payload or {}
            doc_id = payload.get("doc_id", str(result.id))
            content = payload.get("content", "")
            title = payload.get("title", "")
            url = payload.get("url", "")
            score = result.score

            # Empty url/doc_id must not match: "" is a substring of every URI.
            if resources and not any(
                (url and url in resource.uri) or (doc_id and doc_id in resource.uri)
                for resource in resources
            ):
                continue

            if doc_id not in documents:
                documents[doc_id] = Document(id=doc_id, url=url, title=title, chunks=[])

            documents[doc_id].chunks.append(Chunk(content=content, similarity=score))

        return list(documents.values())

    def create_collection(self) -> None:
        """Make sure the collection exists, connecting first if necessary."""
        if not self.client:
            # _connect() already ensures the collection.
            self._connect()
        else:
            self._ensure_collection_exists()

    def load_examples(self, force_reload: bool = False) -> None:
        """Load the example markdown files into the collection.

        Args:
            force_reload: When true, previously stored example points are
                deleted before loading.
        """
        if not self.client:
            self._connect()

        if force_reload:
            self._clear_example_documents()

        self._load_example_files()

    def _clear_example_documents(self) -> None:
        """Delete every point whose ``source`` is ``"examples"``; errors are only logged."""
        try:
            all_points = self._scroll_all_points(
                scroll_filter=Filter(
                    must=[
                        FieldCondition(key="source", match=MatchValue(value="examples"))
                    ]
                ),
                with_payload=False,
                with_vectors=False,
            )

            if all_points:
                # Pass ids through untouched: stringifying an integer id would
                # turn it into an invalid (non-UUID) string id.
                point_ids = [point.id for point in all_points]
                self.client.delete(
                    collection_name=self.collection_name, points_selector=point_ids
                )
                logger.info("Cleared %d existing example documents", len(point_ids))

        except Exception as e:
            logger.warning("Could not clear existing examples: %s", e)

    def get_loaded_examples(self) -> List[Dict[str, str]]:
        """Return ``id``/``title``/``file``/``url`` for each stored example point."""
        if not self.client:
            self._connect()

        all_points = self._scroll_all_points(
            scroll_filter=Filter(
                must=[FieldCondition(key="source", match=MatchValue(value="examples"))]
            ),
            with_payload=True,
            with_vectors=False,
        )

        examples = []
        for point in all_points:
            payload = point.payload or {}
            examples.append(
                {
                    "id": payload.get("doc_id", str(point.id)),
                    "title": payload.get("title", ""),
                    "file": payload.get("file", ""),
                    "url": payload.get("url", ""),
                }
            )

        return examples

    def close(self) -> None:
        """Close the Qdrant client, if any, and forget it and the vector store.

        Errors raised by the client's ``close`` are logged, never propagated;
        the references are dropped either way so a broken client is not reused.
        """
        client = getattr(self, "client", None)
        if not client:
            return
        try:
            if hasattr(client, "close"):
                client.close()
        except Exception as e:
            logger.warning("Exception occurred while closing QdrantProvider: %s", e)
        finally:
            self.client = None
            self.vector_store = None

    def __del__(self) -> None:
        """Release the client when the provider is garbage-collected."""
        self.close()
|
|
|
|
|
|
def load_examples() -> None:
    """Load the example documents into Qdrant when the environment asks for it.

    Nothing happens unless ``RAG_PROVIDER`` equals ``"qdrant"`` and
    ``QDRANT_AUTO_LOAD_EXAMPLES`` is enabled (it defaults to off here).
    """
    auto_load_examples = get_bool_env("QDRANT_AUTO_LOAD_EXAMPLES", False)
    rag_provider = get_str_env("RAG_PROVIDER", "")
    if rag_provider != "qdrant" or not auto_load_examples:
        return

    provider = QdrantProvider()
    try:
        provider.load_examples()
    finally:
        # Release the client now instead of waiting for garbage collection.
        provider.close()
|