# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

"""Qdrant-backed RAG retriever.

Provides :class:`QdrantProvider`, a :class:`Retriever` implementation that
stores markdown example documents in a Qdrant collection and retrieves them
by embedding similarity, plus :class:`DashscopeEmbeddings`, a thin embeddings
client for DashScope's OpenAI-compatible endpoint.
"""

import hashlib
import logging
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set

from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from openai import OpenAI
from qdrant_client import QdrantClient, grpc
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from src.config.loader import get_bool_env, get_int_env, get_str_env
from src.rag.retriever import Chunk, Document, Resource, Retriever

logger = logging.getLogger(__name__)

# Page size used when scrolling through a Qdrant collection.
SCROLL_SIZE = 64


class DashscopeEmbeddings:
    """Embeddings client for DashScope's OpenAI-compatible API.

    Exposes the same ``embed_query`` / ``embed_documents`` interface that
    LangChain embedding classes provide, so it can be used interchangeably
    with :class:`OpenAIEmbeddings`.
    """

    def __init__(self, **kwargs: Any) -> None:
        # DashScope serves an OpenAI-compatible endpoint, so the OpenAI SDK
        # is reused with a custom base_url.
        self._client: OpenAI = OpenAI(
            api_key=kwargs.get("api_key", ""), base_url=kwargs.get("base_url", "")
        )
        self._model: str = kwargs.get("model", "")
        self._encoding_format: str = kwargs.get("encoding_format", "float")

    def _embed(self, texts: Sequence[str]) -> List[List[float]]:
        """Embed *texts*, coercing non-string items to ``str``.

        Returns an empty list for empty input without calling the API.
        """
        clean_texts = [t if isinstance(t, str) else str(t) for t in texts]
        if not clean_texts:
            return []
        resp = self._client.embeddings.create(
            model=self._model,
            input=clean_texts,
            encoding_format=self._encoding_format,
        )
        return [d.embedding for d in resp.data]

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string."""
        embeddings = self._embed([text])
        return embeddings[0] if embeddings else []

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input document."""
        return self._embed(texts)


class QdrantProvider(Retriever):
    """Retriever that stores and searches documents in a Qdrant collection.

    All configuration is read from ``QDRANT_*`` environment variables. The
    Qdrant connection is established lazily on first use via :meth:`_connect`.
    """

    def __init__(self) -> None:
        self.location: str = get_str_env("QDRANT_LOCATION", ":memory:")
        self.api_key: str = get_str_env("QDRANT_API_KEY", "")
        self.collection_name: str = get_str_env("QDRANT_COLLECTION", "documents")
        top_k_raw = get_str_env("QDRANT_TOP_K", "10")
        # Fall back to 10 when the env value is not a plain non-negative integer.
        self.top_k: int = int(top_k_raw) if top_k_raw.isdigit() else 10
        self.embedding_model_name = get_str_env("QDRANT_EMBEDDING_MODEL")
        self.embedding_api_key = get_str_env("QDRANT_EMBEDDING_API_KEY")
        self.embedding_base_url = get_str_env("QDRANT_EMBEDDING_BASE_URL")
        self.embedding_dim: int = self._get_embedding_dimension(
            self.embedding_model_name
        )
        self.embedding_provider = get_str_env("QDRANT_EMBEDDING_PROVIDER", "openai")
        self.auto_load_examples: bool = get_bool_env("QDRANT_AUTO_LOAD_EXAMPLES", True)
        self.examples_dir: str = get_str_env("QDRANT_EXAMPLES_DIR", "examples")
        self.chunk_size: int = get_int_env("QDRANT_CHUNK_SIZE", 4000)
        self._init_embedding_model()
        # Lazy connection: populated by _connect() on first use.
        self.client: Any = None
        self.vector_store: Any = None

    def _init_embedding_model(self) -> None:
        """Instantiate the embedding client for the configured provider.

        Raises:
            ValueError: if ``QDRANT_EMBEDDING_PROVIDER`` is not a supported
                provider name.
        """
        kwargs = {
            "api_key": self.embedding_api_key,
            "model": self.embedding_model_name,
            "base_url": self.embedding_base_url,
            "encoding_format": "float",
            "dimensions": self.embedding_dim,
        }
        if self.embedding_provider.lower() == "openai":
            self.embedding_model = OpenAIEmbeddings(**kwargs)
        elif self.embedding_provider.lower() == "dashscope":
            self.embedding_model = DashscopeEmbeddings(**kwargs)
        else:
            raise ValueError(
                f"Unsupported embedding provider: {self.embedding_provider}. "
                "Supported providers: openai, dashscope"
            )

    def _get_embedding_dimension(self, model_name: str) -> int:
        """Resolve the vector dimension for *model_name*.

        An explicit ``QDRANT_EMBEDDING_DIM`` env value takes precedence over
        the known-model table; unknown models default to 1536.
        """
        embedding_dims = {
            "text-embedding-ada-002": 1536,
            "text-embedding-v4": 2048,
        }
        explicit_dim = get_int_env("QDRANT_EMBEDDING_DIM", 0)
        if explicit_dim > 0:
            return explicit_dim
        return embedding_dims.get(model_name, 1536)

    def _ensure_collection_exists(self) -> None:
        """Create the configured collection (cosine distance) if missing."""
        if not self.client.collection_exists(self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim, distance=Distance.COSINE
                ),
            )
            logger.info("Created Qdrant collection: %s", self.collection_name)

    def _load_example_files(self) -> None:
        """Ingest markdown files from the examples directory into Qdrant.

        Files whose content-derived doc id already exists in the collection
        are skipped; per-file errors are logged and do not abort the run.
        """
        current_file = Path(__file__)
        # assumes this module lives three levels below the project root —
        # TODO confirm if the package layout changes
        project_root = current_file.parent.parent.parent
        examples_path = project_root / self.examples_dir
        if not examples_path.exists():
            logger.info("Examples directory not found: %s", examples_path)
            return
        logger.info("Loading example files from: %s", examples_path)
        md_files = list(examples_path.glob("*.md"))
        if not md_files:
            logger.info("No markdown files found in examples directory")
            return
        existing_docs = self._get_existing_document_ids()
        loaded_count = 0
        for md_file in md_files:
            doc_id = self._generate_doc_id(md_file)
            if doc_id in existing_docs:
                continue
            try:
                content = md_file.read_text(encoding="utf-8")
                title = self._extract_title_from_markdown(content, md_file.name)
                chunks = self._split_content(content)
                for i, chunk in enumerate(chunks):
                    # Single-chunk documents keep the bare doc id.
                    chunk_id = f"{doc_id}_chunk_{i}" if len(chunks) > 1 else doc_id
                    self._insert_document_chunk(
                        doc_id=chunk_id,
                        content=chunk,
                        title=title,
                        url=f"qdrant://{self.collection_name}/{md_file.name}",
                        metadata={"source": "examples", "file": md_file.name},
                    )
                loaded_count += 1
                logger.debug("Loaded example markdown: %s", md_file.name)
            except Exception as e:
                logger.warning("Error loading %s: %s", md_file.name, e)
        logger.info("Successfully loaded %d example files into Qdrant", loaded_count)

    def _generate_doc_id(self, file_path: Path) -> str:
        """Build a stable doc id from file name, size and mtime.

        The hash changes when the file is modified, so edited examples are
        re-ingested rather than skipped as duplicates.
        """
        file_stat = file_path.stat()
        content_hash = hashlib.md5(
            f"{file_path.name}_{file_stat.st_size}_{file_stat.st_mtime}".encode()
        ).hexdigest()[:8]
        return f"example_{file_path.stem}_{content_hash}"

    def _extract_title_from_markdown(self, content: str, filename: str) -> str:
        """Return the first level-1 heading, or a title derived from *filename*."""
        lines = content.split("\n")
        for line in lines:
            line = line.strip()
            if line.startswith("# "):
                return line[2:].strip()
        return filename.replace(".md", "").replace("_", " ").title()

    def _split_content(self, content: str) -> List[str]:
        """Split *content* into chunks of at most ``chunk_size`` characters.

        Splitting happens on paragraph (blank-line) boundaries; a single
        paragraph longer than ``chunk_size`` is kept intact as one chunk.
        """
        if len(content) <= self.chunk_size:
            return [content]
        chunks = []
        paragraphs = content.split("\n\n")
        current_chunk = ""
        for paragraph in paragraphs:
            if len(current_chunk) + len(paragraph) <= self.chunk_size:
                current_chunk += paragraph + "\n\n"
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = paragraph + "\n\n"
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def _string_to_uuid(self, text: str) -> str:
        """Map *text* deterministically to a UUID string (UUIDv5, DNS namespace)."""
        namespace = uuid.NAMESPACE_DNS
        return str(uuid.uuid5(namespace, text))

    def _scroll_all_points(
        self,
        scroll_filter: Optional[Filter] = None,
        with_payload: bool = True,
        with_vectors: bool = False,
    ) -> List[Any]:
        """Page through the whole collection and return every matching point."""
        results = []
        next_offset = None
        stop_scrolling = False
        while not stop_scrolling:
            points, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=SCROLL_SIZE,
                offset=next_offset,
                with_payload=with_payload,
                with_vectors=with_vectors,
            )
            # gRPC clients signal "no more pages" with a zero-valued PointId
            # instead of None, so both sentinels must be checked.
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId)
                and getattr(next_offset, "num", 0) == 0
                and getattr(next_offset, "uuid", "") == ""
            )
            results.extend(points)
        return results

    def _get_existing_document_ids(self) -> Set[str]:
        """Return the doc_ids currently stored in the collection (best effort).

        Any error yields an empty set, which at worst re-ingests examples.
        """
        try:
            points = self._scroll_all_points(with_payload=True, with_vectors=False)
            return {
                point.payload.get("doc_id", str(point.id))
                for point in points
                if point.payload
            }
        except Exception as e:
            # Best-effort: log instead of swallowing silently so repeated
            # re-ingestion is diagnosable.
            logger.debug("Could not list existing document ids: %s", e)
            return set()

    def _insert_document_chunk(
        self, doc_id: str, content: str, title: str, url: str, metadata: Dict[str, Any]
    ) -> None:
        """Embed *content* and upsert it as a single point (waits for commit)."""
        embedding = self._get_embedding(content)
        payload = {
            "doc_id": doc_id,
            "content": content,
            "title": title,
            "url": url,
            **metadata,
        }
        # Qdrant point ids must be UUIDs or ints; derive one from the doc id
        # so re-inserting the same chunk overwrites rather than duplicates.
        point_id = self._string_to_uuid(doc_id)
        point = PointStruct(id=point_id, vector=embedding, payload=payload)
        self.client.upsert(
            collection_name=self.collection_name, points=[point], wait=True
        )

    def _connect(self) -> None:
        """Create the Qdrant client, ensure the collection, build the store.

        The LangChain vector store is optional: on failure it is left as
        ``None`` and callers fall back to raw client queries.
        """
        client_kwargs = {"location": self.location}
        if self.api_key:
            client_kwargs["api_key"] = self.api_key
        self.client = QdrantClient(**client_kwargs)
        self._ensure_collection_exists()
        try:
            self.vector_store = QdrantVectorStore(
                client=self.client,
                collection_name=self.collection_name,
                embedding=self.embedding_model,
            )
        except Exception:
            self.vector_store = None

    def _get_embedding(self, text: str) -> List[float]:
        """Return the embedding for *text* (whitespace-stripped)."""
        return self.embedding_model.embed_query(text=text.strip())

    def list_resources(self, query: Optional[str] = None) -> List[Resource]:
        """List example resources stored in Qdrant.

        With a *query* and a working vector store, results come from a
        similarity search; otherwise all example points are scrolled. Any
        connection or query failure falls back to listing local markdown
        files that have not yet been ingested.
        """
        resources: List[Resource] = []
        if not self.client:
            try:
                self._connect()
            except Exception:
                return self._list_local_markdown_resources()
        try:
            if query and self.vector_store:
                docs = self.vector_store.similarity_search(
                    query, k=100, filter={"source": "examples"}
                )
                for d in docs:
                    meta = d.metadata or {}
                    uri = meta.get("url", "") or f"qdrant://{meta.get('id', '')}"
                    # Multiple chunks of one document share a URI; dedupe.
                    if any(r.uri == uri for r in resources):
                        continue
                    resources.append(
                        Resource(
                            uri=uri,
                            title=meta.get("title", "") or meta.get("id", "Unnamed"),
                            description="Stored Qdrant document",
                        )
                    )
            else:
                all_points = self._scroll_all_points(
                    scroll_filter=Filter(
                        must=[
                            FieldCondition(
                                key="source", match=MatchValue(value="examples")
                            )
                        ]
                    ),
                    with_payload=True,
                    with_vectors=False,
                )
                for point in all_points:
                    payload = point.payload or {}
                    doc_id = payload.get("doc_id", str(point.id))
                    uri = payload.get("url", "") or f"qdrant://{doc_id}"
                    resources.append(
                        Resource(
                            uri=uri,
                            title=payload.get("title", "") or doc_id,
                            description="Stored Qdrant document",
                        )
                    )
            logger.info(
                "Successfully listed %d resources from Qdrant collection: %s",
                len(resources),
                self.collection_name,
            )
        except Exception:
            logger.warning(
                "Failed to query Qdrant for resources, falling back to local examples."
            )
            return self._list_local_markdown_resources()
        return resources

    def _list_local_markdown_resources(self) -> List[Resource]:
        """Enumerate local example markdown files as (not yet ingested) resources."""
        current_file = Path(__file__)
        project_root = current_file.parent.parent.parent
        examples_path = project_root / self.examples_dir
        if not examples_path.exists():
            return []
        md_files = list(examples_path.glob("*.md"))
        resources: list[Resource] = []
        for md_file in md_files:
            try:
                content = md_file.read_text(encoding="utf-8", errors="ignore")
                title = self._extract_title_from_markdown(content, md_file.name)
                uri = f"qdrant://{self.collection_name}/{md_file.name}"
                resources.append(
                    Resource(
                        uri=uri,
                        title=title,
                        description="Local markdown example (not yet ingested)",
                    )
                )
            except Exception:
                # Unreadable file: skip it rather than fail the listing.
                continue
        return resources

    def query_relevant_documents(
        self, query: str, resources: Optional[List[Resource]] = None
    ) -> List[Document]:
        """Return the top-k documents relevant to *query*.

        Chunks are grouped by ``doc_id`` into :class:`Document` objects. When
        *resources* is given, only hits whose URL or doc id appears in one of
        the resource URIs are kept.
        """
        resources = resources or []
        if not self.client:
            self._connect()
        query_embedding = self._get_embedding(query)
        search_results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_embedding,
            limit=self.top_k,
            with_payload=True,
        ).points
        documents = {}
        for result in search_results:
            payload = result.payload or {}
            doc_id = payload.get("doc_id", str(result.id))
            content = payload.get("content", "")
            title = payload.get("title", "")
            url = payload.get("url", "")
            score = result.score
            if resources:
                doc_in_resources = False
                for resource in resources:
                    if (url and url in resource.uri) or doc_id in resource.uri:
                        doc_in_resources = True
                        break
                if not doc_in_resources:
                    continue
            if doc_id not in documents:
                documents[doc_id] = Document(id=doc_id, url=url, title=title, chunks=[])
            chunk = Chunk(content=content, similarity=score)
            documents[doc_id].chunks.append(chunk)
        return list(documents.values())

    def create_collection(self) -> None:
        """Ensure the configured collection exists (connecting if needed)."""
        if not self.client:
            self._connect()
        else:
            self._ensure_collection_exists()

    def load_examples(self, force_reload: bool = False) -> None:
        """Ingest example markdown files, optionally clearing existing ones first."""
        if not self.client:
            self._connect()
        if force_reload:
            self._clear_example_documents()
        self._load_example_files()

    def _clear_example_documents(self) -> None:
        """Delete all points tagged ``source=examples`` (best effort)."""
        try:
            all_points = self._scroll_all_points(
                scroll_filter=Filter(
                    must=[
                        FieldCondition(key="source", match=MatchValue(value="examples"))
                    ]
                ),
                with_payload=False,
                with_vectors=False,
            )
            if all_points:
                point_ids = [str(point.id) for point in all_points]
                self.client.delete(
                    collection_name=self.collection_name, points_selector=point_ids
                )
                logger.info("Cleared %d existing example documents", len(point_ids))
        except Exception as e:
            logger.warning("Could not clear existing examples: %s", e)

    def get_loaded_examples(self) -> List[Dict[str, str]]:
        """Return id/title/file/url summaries of every ingested example chunk."""
        if not self.client:
            self._connect()
        all_points = self._scroll_all_points(
            scroll_filter=Filter(
                must=[FieldCondition(key="source", match=MatchValue(value="examples"))]
            ),
            with_payload=True,
            with_vectors=False,
        )
        examples = []
        for point in all_points:
            payload = point.payload or {}
            examples.append(
                {
                    "id": payload.get("doc_id", str(point.id)),
                    "title": payload.get("title", ""),
                    "file": payload.get("file", ""),
                    "url": payload.get("url", ""),
                }
            )
        return examples

    def close(self) -> None:
        """Close the Qdrant client and drop references; safe to call twice."""
        if hasattr(self, "client") and self.client:
            try:
                if hasattr(self.client, "close"):
                    self.client.close()
            except Exception as e:
                logger.warning("Exception occurred while closing QdrantProvider: %s", e)
            finally:
                # Clear references even when close() raised, so a retry does
                # not attempt to reuse a half-closed client.
                self.client = None
                self.vector_store = None

    def __del__(self) -> None:
        self.close()


def load_examples() -> None:
    """Module-level helper: ingest examples when configured to do so.

    Only acts when ``RAG_PROVIDER`` is ``qdrant`` and
    ``QDRANT_AUTO_LOAD_EXAMPLES`` is truthy.
    """
    auto_load_examples = get_bool_env("QDRANT_AUTO_LOAD_EXAMPLES", False)
    rag_provider = get_str_env("RAG_PROVIDER", "")
    if rag_provider == "qdrant" and auto_load_examples:
        provider = QdrantProvider()
        provider.load_examples()