diff --git a/.env.example b/.env.example index 89bb6b4..3c17dc0 100644 --- a/.env.example +++ b/.env.example @@ -35,6 +35,7 @@ TAVILY_API_KEY=tvly-xxx # RAGFLOW_API_URL="http://localhost:9388" # RAGFLOW_API_KEY="ragflow-xxx" # RAGFLOW_RETRIEVAL_SIZE=10 +# RAGFLOW_CROSS_LANGUAGES=English,Chinese,Spanish,French,German,Japanese,Korean # Optional. To use RAGFlow's cross-language search, list the languages separated by single commas with no spaces # Optional, volcengine TTS for generating podcast VOLCENGINE_TTS_APPID=xxx diff --git a/README.md b/README.md index 5973ea5..5324e2c 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,7 @@ DeerFlow support private knowledgebase such as ragflow and vikingdb, so that you RAGFLOW_API_URL="http://localhost:9388" RAGFLOW_API_KEY="ragflow-xxx" RAGFLOW_RETRIEVAL_SIZE=10 + RAGFLOW_CROSS_LANGUAGES=English,Chinese,Spanish,French,German,Japanese,Korean ``` ## Features diff --git a/src/rag/ragflow.py b/src/rag/ragflow.py index 529ff2c..ecb791e 100644 --- a/src/rag/ragflow.py +++ b/src/rag/ragflow.py @@ -4,6 +4,7 @@ import os import requests from src.rag.retriever import Chunk, Document, Resource, Retriever +from typing import List, Optional from urllib.parse import urlparse @@ -15,6 +16,7 @@ class RAGFlowProvider(Retriever): api_url: str api_key: str page_size: int = 10 + cross_languages: Optional[List[str]] = None def __init__(self): api_url = os.getenv("RAGFLOW_API_URL") @@ -31,6 +33,11 @@ class RAGFlowProvider(Retriever): if page_size: self.page_size = int(page_size) + self.cross_languages = None + cross_languages = os.getenv("RAGFLOW_CROSS_LANGUAGES") + if cross_languages: + self.cross_languages = cross_languages.split(",") + def query_relevant_documents( self, query: str, resources: list[Resource] = [] ) -> list[Document]: @@ -55,6 +62,9 @@ class RAGFlowProvider(Retriever): "page_size": self.page_size, } + if self.cross_languages: + payload["cross_languages"] = self.cross_languages + response = requests.post( 
f"{self.api_url}/api/v1/retrieval", headers=headers, json=payload ) diff --git a/tests/unit/rag/test_ragflow.py b/tests/unit/rag/test_ragflow.py index b5310ad..42b04de 100644 --- a/tests/unit/rag/test_ragflow.py +++ b/tests/unit/rag/test_ragflow.py @@ -68,6 +68,14 @@ def test_init_page_size(monkeypatch): assert provider.page_size == 5 +def test_init_cross_language(monkeypatch): + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + monkeypatch.setenv("RAGFLOW_CROSS_LANGUAGES", "lang1,lang2") + provider = RAGFlowProvider() + assert provider.cross_languages == ["lang1", "lang2"] + + def test_init_missing_env(monkeypatch): monkeypatch.delenv("RAGFLOW_API_URL", raising=False) monkeypatch.setenv("RAGFLOW_API_KEY", "key")