feat: 1. replace black with ruff for fomatting and sort import (#489)

2. use tavily from`langchain-tavily` rather than the older one from `langchain-community`

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
zgjja
2025-08-17 22:57:23 +08:00
committed by GitHub
parent 1bfec3ad05
commit 3b4e993531
62 changed files with 251 additions and 234 deletions

View File

@@ -3,9 +3,9 @@
from langgraph.prebuilt import create_react_agent
from src.prompts import apply_prompt_template
from src.llms.llm import get_llm_by_type
from src.config.agents import AGENT_LLM_MAP
from src.llms.llm import get_llm_by_type
from src.prompts import apply_prompt_template
# Create agents using configured LLM types

View File

@@ -1,12 +1,12 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .loader import load_yaml_config
from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
from dotenv import load_dotenv
from .loader import load_yaml_config
from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
# Load environment variables
load_dotenv()

View File

@@ -8,8 +8,8 @@ from typing import Any, Optional
from langchain_core.runnables import RunnableConfig
from src.rag.retriever import Resource
from src.config.report_style import ReportStyle
from src.rag.retriever import Resource
logger = logging.getLogger(__name__)

View File

@@ -2,8 +2,9 @@
# SPDX-License-Identifier: MIT
import os
from typing import Any, Dict
import yaml
from typing import Dict, Any
def replace_env_vars(value: str) -> str:

View File

@@ -1,8 +1,9 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import enum
import os
from dotenv import load_dotenv
load_dotenv()

View File

@@ -1,7 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .builder import build_graph_with_memory, build_graph
from .builder import build_graph, build_graph_with_memory
__all__ = [
"build_graph_with_memory",

View File

@@ -1,21 +1,22 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from src.prompts.planner_model import StepType
from .types import State
from .nodes import (
background_investigation_node,
coder_node,
coordinator_node,
human_feedback_node,
planner_node,
reporter_node,
research_team_node,
researcher_node,
coder_node,
human_feedback_node,
background_investigation_node,
)
from .types import State
def continue_to_running_research_team(state: State):

View File

@@ -9,27 +9,26 @@ from typing import Annotated, Literal
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langgraph.types import Command, interrupt
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.types import Command, interrupt
from src.agents import create_agent
from src.tools.search import LoggedTavilySearch
from src.tools import (
crawl_tool,
get_web_search_tool,
get_retriever_tool,
python_repl_tool,
)
from src.config.agents import AGENT_LLM_MAP
from src.config.configuration import Configuration
from src.llms.llm import get_llm_by_type
from src.prompts.planner_model import Plan
from src.prompts.template import apply_prompt_template
from src.tools import (
crawl_tool,
get_retriever_tool,
get_web_search_tool,
python_repl_tool,
)
from src.tools.search import LoggedTavilySearch
from src.utils.json_utils import repair_json_output
from .types import State
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
from .types import State
logger = logging.getLogger(__name__)
@@ -106,7 +105,7 @@ def planner_node(
elif AGENT_LLM_MAP["planner"] == "basic":
llm = get_llm_by_type("basic").with_structured_output(
Plan,
method="json_mode",
# method="json_mode",
)
else:
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])

View File

@@ -1,15 +1,14 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from pathlib import Path
from typing import Any, Dict
import os
import httpx
from pathlib import Path
from typing import Any, Dict, get_args
import httpx
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_deepseek import ChatDeepSeek
from typing import get_args
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from src.config import load_yaml_config
from src.config.agents import LLMType

View File

@@ -8,8 +8,8 @@ from langchain.schema import HumanMessage
from src.config.agents import AGENT_LLM_MAP
from src.llms.llm import get_llm_by_type
from src.prompts.template import apply_prompt_template
from src.prompt_enhancer.graph.state import PromptEnhancerState
from src.prompts.template import apply_prompt_template
logger = logging.getLogger(__name__)
@@ -21,7 +21,6 @@ def prompt_enhancer_node(state: PromptEnhancerState):
model = get_llm_by_type(AGENT_LLM_MAP["prompt_enhancer"])
try:
# Create messages with context if provided
context_info = ""
if state.get("context"):

View File

@@ -1,7 +1,8 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from typing import TypedDict, Optional
from typing import Optional, TypedDict
from src.config.report_style import ReportStyle

View File

@@ -1,11 +1,13 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import dataclasses
import os
from datetime import datetime
from jinja2 import Environment, FileSystemLoader, select_autoescape
from langgraph.prebuilt.chat_agent_executor import AgentState
from src.config.configuration import Configuration
# Initialize Jinja2 environment

View File

@@ -3,6 +3,7 @@
import asyncio
import logging
from langgraph.graph import END, START, StateGraph
from src.prose.graph.prose_continue_node import prose_continue_node

View File

@@ -7,8 +7,8 @@ from langchain.schema import HumanMessage, SystemMessage
from src.config.agents import AGENT_LLM_MAP
from src.llms.llm import get_llm_by_type
from src.prose.graph.state import ProseState
from src.prompts.template import get_prompt_template
from src.prose.graph.state import ProseState
logger = logging.getLogger(__name__)

View File

@@ -1,10 +1,10 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .retriever import Retriever, Document, Resource, Chunk
from .ragflow import RAGFlowProvider
from .vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
from .builder import build_retriever
from .ragflow import RAGFlowProvider
from .retriever import Chunk, Document, Resource, Retriever
from .vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
__all__ = [
Retriever,

View File

@@ -3,8 +3,8 @@
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
from src.rag.ragflow import RAGFlowProvider
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
from src.rag.retriever import Retriever
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
def build_retriever() -> Retriever | None:

View File

@@ -2,11 +2,13 @@
# SPDX-License-Identifier: MIT
import os
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
from typing import List, Optional
from urllib.parse import urlparse
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
class RAGFlowProvider(Retriever):
"""

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT
import abc
from pydantic import BaseModel, Field

View File

@@ -1,16 +1,18 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import requests
import json
import hashlib
import hmac
import json
import os
import urllib.parse
from datetime import datetime
from src.rag.retriever import Chunk, Document, Resource, Retriever
from urllib.parse import urlparse
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
class VikingDBKnowledgeBaseProvider(Retriever):
"""

View File

@@ -5,8 +5,8 @@ from typing import List, Optional, Union
from pydantic import BaseModel, Field
from src.rag.retriever import Resource
from src.config.report_style import ReportStyle
from src.rag.retriever import Resource
class ContentItem(BaseModel):

View File

@@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional
from fastapi import HTTPException
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
logger = logging.getLogger(__name__)

View File

@@ -5,10 +5,11 @@ import logging
from typing import Annotated
from langchain_core.tools import tool
from .decorators import log_io
from src.crawler import Crawler
from .decorators import log_io
logger = logging.getLogger(__name__)

View File

@@ -1,8 +1,8 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import logging
import functools
import logging
from typing import Any, Callable, Type, TypeVar
logger = logging.getLogger(__name__)

View File

@@ -4,8 +4,10 @@
import logging
import os
from typing import Annotated, Optional
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL
from .decorators import log_io

View File

@@ -3,15 +3,16 @@
import logging
from typing import List, Optional, Type
from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from src.config.tools import SELECTED_RAG_PROVIDER
from src.rag import Document, Retriever, Resource, build_retriever
from src.rag import Document, Resource, Retriever, build_retriever
logger = logging.getLogger(__name__)
@@ -22,9 +23,7 @@ class RetrieverInput(BaseModel):
class RetrieverTool(BaseTool):
name: str = "local_search_tool"
description: str = (
"Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
)
description: str = "Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords."
args_schema: Type[BaseModel] = RetrieverInput
retriever: Retriever = Field(default_factory=Retriever)

View File

@@ -17,18 +17,16 @@ from langchain_community.utilities import (
WikipediaAPIWrapper,
)
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
from src.config import load_yaml_config
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchResultsWithImages,
)
from src.config import SELECTED_SEARCH_ENGINE, SearchEngine, load_yaml_config
from src.tools.decorators import create_logged_tool
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchWithImages,
)
logger = logging.getLogger(__name__)
# Create logged versions of the search tools
LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
LoggedTavilySearch = create_logged_tool(TavilySearchWithImages)
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
LoggedBraveSearch = create_logged_tool(BraveSearch)
LoggedArxivSearch = create_logged_tool(ArxivQueryRun)

View File

@@ -1,4 +1,4 @@
from .tavily_search_api_wrapper import EnhancedTavilySearchAPIWrapper
from .tavily_search_results_with_images import TavilySearchResultsWithImages
from .tavily_search_results_with_images import TavilySearchWithImages
__all__ = ["EnhancedTavilySearchAPIWrapper", "TavilySearchResultsWithImages"]
__all__ = ["EnhancedTavilySearchAPIWrapper", "TavilySearchWithImages"]

View File

@@ -7,8 +7,8 @@ from typing import Dict, List, Optional
import aiohttp
import requests
from langchain_community.utilities.tavily_search import TAVILY_API_URL
from langchain_community.utilities.tavily_search import (
from langchain_tavily._utilities import TAVILY_API_URL
from langchain_tavily.tavily_search import (
TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper,
)

View File

@@ -1,15 +1,15 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import logging
import json
import logging
from typing import Dict, List, Optional, Tuple, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_community.tools.tavily_search.tool import TavilySearchResults
from langchain_tavily.tavily_search import TavilySearch
from pydantic import Field
from src.tools.tavily_search.tavily_search_api_wrapper import (
@@ -19,7 +19,7 @@ from src.tools.tavily_search.tavily_search_api_wrapper import (
logger = logging.getLogger(__name__)
class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[override, override]
class TavilySearchWithImages(TavilySearch): # type: ignore[override, override]
"""Tool that queries the Tavily Search API and gets back json.
Setup:
@@ -34,9 +34,9 @@ class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[overri
.. code-block:: python
from langchain_community.tools import TavilySearchResults
from langchain_tavily.tavily_search import TavilySearch
tool = TavilySearchResults(
tool = TavilySearch(
max_results=5,
include_answer=True,
include_raw_content=True,
@@ -102,7 +102,9 @@ class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[overri
Default is False.
"""
api_wrapper: EnhancedTavilySearchAPIWrapper = Field(default_factory=EnhancedTavilySearchAPIWrapper) # type: ignore[arg-type]
api_wrapper: EnhancedTavilySearchAPIWrapper = Field(
default_factory=EnhancedTavilySearchAPIWrapper
) # type: ignore[arg-type]
def _run(
self,

View File

@@ -6,10 +6,11 @@ Text-to-Speech module using volcengine TTS API.
"""
import json
import uuid
import logging
import uuid
from typing import Any, Dict, Optional
import requests
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)

View File

@@ -1,8 +1,8 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import logging
import json
import logging
from typing import Any
import json_repair

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT
import logging
from src.config.configuration import get_recursion_limit
from src.graph import build_graph