mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-12 01:54:45 +08:00
feat: 1. replace black with ruff for fomatting and sort import (#489)
2. use tavily from`langchain-tavily` rather than the older one from `langchain-community` Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -3,9 +3,9 @@
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from src.prompts import apply_prompt_template
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.prompts import apply_prompt_template
|
||||
|
||||
|
||||
# Create agents using configured LLM types
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .loader import load_yaml_config
|
||||
from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .loader import load_yaml_config
|
||||
from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
|
||||
from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ from typing import Any, Optional
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.rag.retriever import Resource
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.rag.retriever import Resource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import yaml
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def replace_env_vars(value: str) -> str:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import enum
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .builder import build_graph_with_memory, build_graph
|
||||
from .builder import build_graph, build_graph_with_memory
|
||||
|
||||
__all__ = [
|
||||
"build_graph_with_memory",
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
|
||||
from src.prompts.planner_model import StepType
|
||||
|
||||
from .types import State
|
||||
from .nodes import (
|
||||
background_investigation_node,
|
||||
coder_node,
|
||||
coordinator_node,
|
||||
human_feedback_node,
|
||||
planner_node,
|
||||
reporter_node,
|
||||
research_team_node,
|
||||
researcher_node,
|
||||
coder_node,
|
||||
human_feedback_node,
|
||||
background_investigation_node,
|
||||
)
|
||||
from .types import State
|
||||
|
||||
|
||||
def continue_to_running_research_team(state: State):
|
||||
|
||||
@@ -9,27 +9,26 @@ from typing import Annotated, Literal
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command, interrupt
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langgraph.types import Command, interrupt
|
||||
|
||||
from src.agents import create_agent
|
||||
from src.tools.search import LoggedTavilySearch
|
||||
from src.tools import (
|
||||
crawl_tool,
|
||||
get_web_search_tool,
|
||||
get_retriever_tool,
|
||||
python_repl_tool,
|
||||
)
|
||||
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.config.configuration import Configuration
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.prompts.planner_model import Plan
|
||||
from src.prompts.template import apply_prompt_template
|
||||
from src.tools import (
|
||||
crawl_tool,
|
||||
get_retriever_tool,
|
||||
get_web_search_tool,
|
||||
python_repl_tool,
|
||||
)
|
||||
from src.tools.search import LoggedTavilySearch
|
||||
from src.utils.json_utils import repair_json_output
|
||||
|
||||
from .types import State
|
||||
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
from .types import State
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -106,7 +105,7 @@ def planner_node(
|
||||
elif AGENT_LLM_MAP["planner"] == "basic":
|
||||
llm = get_llm_by_type("basic").with_structured_output(
|
||||
Plan,
|
||||
method="json_mode",
|
||||
# method="json_mode",
|
||||
)
|
||||
else:
|
||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
import os
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, get_args
|
||||
|
||||
import httpx
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from typing import get_args
|
||||
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
||||
|
||||
from src.config import load_yaml_config
|
||||
from src.config.agents import LLMType
|
||||
|
||||
@@ -8,8 +8,8 @@ from langchain.schema import HumanMessage
|
||||
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.prompts.template import apply_prompt_template
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.prompts.template import apply_prompt_template
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,7 +21,6 @@ def prompt_enhancer_node(state: PromptEnhancerState):
|
||||
model = get_llm_by_type(AGENT_LLM_MAP["prompt_enhancer"])
|
||||
|
||||
try:
|
||||
|
||||
# Create messages with context if provided
|
||||
context_info = ""
|
||||
if state.get("context"):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import TypedDict, Optional
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import dataclasses
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
|
||||
from src.config.configuration import Configuration
|
||||
|
||||
# Initialize Jinja2 environment
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
|
||||
from src.prose.graph.prose_continue_node import prose_continue_node
|
||||
|
||||
@@ -7,8 +7,8 @@ from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.prose.graph.state import ProseState
|
||||
from src.prompts.template import get_prompt_template
|
||||
from src.prose.graph.state import ProseState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .retriever import Retriever, Document, Resource, Chunk
|
||||
from .ragflow import RAGFlowProvider
|
||||
from .vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
||||
from .builder import build_retriever
|
||||
from .ragflow import RAGFlowProvider
|
||||
from .retriever import Chunk, Document, Resource, Retriever
|
||||
from .vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
||||
|
||||
__all__ = [
|
||||
Retriever,
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
|
||||
from src.rag.ragflow import RAGFlowProvider
|
||||
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
||||
from src.rag.retriever import Retriever
|
||||
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
||||
|
||||
|
||||
def build_retriever() -> Retriever | None:
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import requests
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
|
||||
|
||||
class RAGFlowProvider(Retriever):
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import abc
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from datetime import datetime
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
|
||||
|
||||
class VikingDBKnowledgeBaseProvider(Retriever):
|
||||
"""
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.rag.retriever import Resource
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.rag.retriever import Resource
|
||||
|
||||
|
||||
class ContentItem(BaseModel):
|
||||
|
||||
@@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -5,10 +5,11 @@ import logging
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from .decorators import log_io
|
||||
|
||||
from src.crawler import Crawler
|
||||
|
||||
from .decorators import log_io
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable, Type, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -4,8 +4,10 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langchain_experimental.utilities import PythonREPL
|
||||
|
||||
from .decorators import log_io
|
||||
|
||||
|
||||
|
||||
@@ -3,15 +3,16 @@
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Type
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||
from src.rag import Document, Retriever, Resource, build_retriever
|
||||
from src.rag import Document, Resource, Retriever, build_retriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,9 +23,7 @@ class RetrieverInput(BaseModel):
|
||||
|
||||
class RetrieverTool(BaseTool):
|
||||
name: str = "local_search_tool"
|
||||
description: str = (
|
||||
"Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
|
||||
)
|
||||
description: str = "Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
|
||||
args_schema: Type[BaseModel] = RetrieverInput
|
||||
|
||||
retriever: Retriever = Field(default_factory=Retriever)
|
||||
|
||||
@@ -17,18 +17,16 @@ from langchain_community.utilities import (
|
||||
WikipediaAPIWrapper,
|
||||
)
|
||||
|
||||
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
|
||||
from src.config import load_yaml_config
|
||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||
TavilySearchResultsWithImages,
|
||||
)
|
||||
|
||||
from src.config import SELECTED_SEARCH_ENGINE, SearchEngine, load_yaml_config
|
||||
from src.tools.decorators import create_logged_tool
|
||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||
TavilySearchWithImages,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create logged versions of the search tools
|
||||
LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
|
||||
LoggedTavilySearch = create_logged_tool(TavilySearchWithImages)
|
||||
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
|
||||
LoggedBraveSearch = create_logged_tool(BraveSearch)
|
||||
LoggedArxivSearch = create_logged_tool(ArxivQueryRun)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .tavily_search_api_wrapper import EnhancedTavilySearchAPIWrapper
|
||||
from .tavily_search_results_with_images import TavilySearchResultsWithImages
|
||||
from .tavily_search_results_with_images import TavilySearchWithImages
|
||||
|
||||
__all__ = ["EnhancedTavilySearchAPIWrapper", "TavilySearchResultsWithImages"]
|
||||
__all__ = ["EnhancedTavilySearchAPIWrapper", "TavilySearchWithImages"]
|
||||
|
||||
@@ -7,8 +7,8 @@ from typing import Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_community.utilities.tavily_search import TAVILY_API_URL
|
||||
from langchain_community.utilities.tavily_search import (
|
||||
from langchain_tavily._utilities import TAVILY_API_URL
|
||||
from langchain_tavily.tavily_search import (
|
||||
TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_community.tools.tavily_search.tool import TavilySearchResults
|
||||
from langchain_tavily.tavily_search import TavilySearch
|
||||
from pydantic import Field
|
||||
|
||||
from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
@@ -19,7 +19,7 @@ from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[override, override]
|
||||
class TavilySearchWithImages(TavilySearch): # type: ignore[override, override]
|
||||
"""Tool that queries the Tavily Search API and gets back json.
|
||||
|
||||
Setup:
|
||||
@@ -34,9 +34,9 @@ class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[overri
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.tools import TavilySearchResults
|
||||
from langchain_tavily.tavily_search import TavilySearch
|
||||
|
||||
tool = TavilySearchResults(
|
||||
tool = TavilySearch(
|
||||
max_results=5,
|
||||
include_answer=True,
|
||||
include_raw_content=True,
|
||||
@@ -102,7 +102,9 @@ class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[overri
|
||||
Default is False.
|
||||
"""
|
||||
|
||||
api_wrapper: EnhancedTavilySearchAPIWrapper = Field(default_factory=EnhancedTavilySearchAPIWrapper) # type: ignore[arg-type]
|
||||
api_wrapper: EnhancedTavilySearchAPIWrapper = Field(
|
||||
default_factory=EnhancedTavilySearchAPIWrapper
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
|
||||
@@ -6,10 +6,11 @@ Text-to-Speech module using volcengine TTS API.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
import json_repair
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
|
||||
from src.config.configuration import get_recursion_limit
|
||||
from src.graph import build_graph
|
||||
|
||||
|
||||
Reference in New Issue
Block a user