feat: 1. replace black with ruff for fomatting and sort import (#489)

2. use tavily from`langchain-tavily` rather than the older one from `langchain-community`

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
zgjja
2025-08-17 22:57:23 +08:00
committed by GitHub
parent 1bfec3ad05
commit 3b4e993531
62 changed files with 251 additions and 234 deletions

View File

@@ -1,21 +1,26 @@
from collections import namedtuple
import json
from collections import namedtuple
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from src.graph.nodes import planner_node
from src.graph.nodes import human_feedback_node
from src.graph.nodes import coordinator_node
from src.graph.nodes import reporter_node
from src.graph.nodes import _execute_agent_step
from src.graph.nodes import _setup_and_execute_agent_step
from src.graph.nodes import researcher_node
from src.graph.nodes import (
_execute_agent_step,
_setup_and_execute_agent_step,
coordinator_node,
human_feedback_node,
planner_node,
reporter_node,
researcher_node,
)
# 在这里 mock 掉 get_llm_by_type避免 ValueError
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
from langgraph.types import Command
from src.graph.nodes import background_investigation_node
from src.config import SearchEngine
from langchain_core.messages import HumanMessage
from langgraph.types import Command
from src.config import SearchEngine
from src.graph.nodes import background_investigation_node
# Mock data

View File

@@ -2,7 +2,8 @@
# SPDX-License-Identifier: MIT
import pytest
from src.prompts.template import get_prompt_template, apply_prompt_template
from src.prompts.template import apply_prompt_template, get_prompt_template
def test_get_prompt_template_success():

View File

@@ -1,9 +1,9 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from unittest.mock import patch, MagicMock
import base64
import json
from unittest.mock import MagicMock, patch
from src.tools.tts import VolcengineTTS

View File

@@ -1,8 +1,8 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import sys
import os
import sys
from typing import Annotated
# Import MessagesState directly from langgraph rather than through our application

View File

@@ -3,6 +3,7 @@
import sys
import types
from src.config.configuration import Configuration
# Patch sys.path so relative import works

View File

@@ -3,6 +3,7 @@
import os
import tempfile
from src.config.loader import load_yaml_config, process_dict, replace_env_vars

View File

@@ -1,10 +1,11 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
from unittest.mock import MagicMock, patch
import importlib
import sys
from unittest.mock import MagicMock, patch
import pytest
import src.graph.builder as builder_mod

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT
import pytest
from src.llms import llm

View File

@@ -1,8 +1,9 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from src.prompt_enhancer.graph.builder import build_graph
from src.prompt_enhancer.graph.state import PromptEnhancerState

View File

@@ -1,13 +1,14 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from langchain.schema import HumanMessage, SystemMessage
from src.config.report_style import ReportStyle
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
from src.prompt_enhancer.graph.state import PromptEnhancerState
from src.config.report_style import ReportStyle
@pytest.fixture

View File

@@ -1,8 +1,8 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from src.prompt_enhancer.graph.state import PromptEnhancerState
from src.config.report_style import ReportStyle
from src.prompt_enhancer.graph.state import PromptEnhancerState
def test_prompt_enhancer_state_creation():

View File

@@ -1,8 +1,10 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from src.rag.ragflow import RAGFlowProvider, parse_uri

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT
import pytest
from src.rag.retriever import Chunk, Document, Resource, Retriever

View File

@@ -1,13 +1,15 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import pytest
import json
import hashlib
import hmac
from unittest.mock import patch, MagicMock
import json
import os
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider, parse_uri

View File

@@ -3,15 +3,16 @@
import base64
import os
from unittest.mock import MagicMock, patch, mock_open
from unittest.mock import MagicMock, mock_open, patch
import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException
from src.server.app import app, _make_event, _astream_workflow_generator
from src.config.report_style import ReportStyle
from fastapi.testclient import TestClient
from langchain_core.messages import AIMessageChunk, ToolMessage
from langgraph.types import Command
from langchain_core.messages import ToolMessage
from langchain_core.messages import AIMessageChunk
from src.config.report_style import ReportStyle
from src.server.app import _astream_workflow_generator, _make_event, app
@pytest.fixture
@@ -333,7 +334,6 @@ class TestMCPEndpoint:
def test_mcp_server_metadata_without_enable_configuration(
self, mock_load_tools, client
):
request_data = {
"transport": "stdio",
"command": "test_command",
@@ -547,7 +547,6 @@ class TestAstreamWorkflowGenerator:
@pytest.mark.asyncio
@patch("src.server.app.graph")
async def test_astream_workflow_generator_with_interrupt_feedback(self, mock_graph):
# Mock the async stream
async def mock_astream(*args, **kwargs):
# Verify that Command is passed as input when interrupt_feedback is provided
@@ -620,7 +619,6 @@ class TestAstreamWorkflowGenerator:
@pytest.mark.asyncio
@patch("src.server.app.graph")
async def test_astream_workflow_generator_tool_message(self, mock_graph):
# Mock tool message
mock_tool_message = ToolMessage(content="Tool result", tool_call_id="tool_123")
mock_tool_message.id = "msg_456"
@@ -659,7 +657,6 @@ class TestAstreamWorkflowGenerator:
async def test_astream_workflow_generator_ai_message_with_tool_calls(
self, mock_graph
):
# Mock AI message with tool calls
mock_ai_message = AIMessageChunk(content="Making tool call")
mock_ai_message.id = "msg_789"
@@ -701,7 +698,6 @@ class TestAstreamWorkflowGenerator:
async def test_astream_workflow_generator_ai_message_with_tool_call_chunks(
self, mock_graph
):
# Mock AI message with only tool call chunks
mock_ai_message = AIMessageChunk(content="Streaming tool call")
mock_ai_message.id = "msg_101"
@@ -740,7 +736,6 @@ class TestAstreamWorkflowGenerator:
@pytest.mark.asyncio
@patch("src.server.app.graph")
async def test_astream_workflow_generator_with_finish_reason(self, mock_graph):
# Mock AI message with finish reason
mock_ai_message = AIMessageChunk(content="Complete response")
mock_ai_message.id = "msg_finish"
@@ -780,7 +775,6 @@ class TestAstreamWorkflowGenerator:
@pytest.mark.asyncio
@patch("src.server.app.graph")
async def test_astream_workflow_generator_config_passed_correctly(self, mock_graph):
mock_ai_message = AIMessageChunk(content="Test")
mock_ai_message.id = "test_id"
mock_ai_message.response_metadata = {}

View File

@@ -1,24 +1,25 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from pydantic import ValidationError
import src.server.mcp_utils as mcp_utils # Assuming mcp_utils is the module to test
from src.config.report_style import ReportStyle
from src.rag.retriever import Resource
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi import HTTPException
from src.server.chat_request import (
ContentItem,
ChatMessage,
ChatRequest,
TTSRequest,
ContentItem,
EnhancePromptRequest,
GeneratePodcastRequest,
GeneratePPTRequest,
GenerateProseRequest,
EnhancePromptRequest,
TTSRequest,
)
import src.server.mcp_utils as mcp_utils # Assuming mcp_utils is the module to test
def test_content_item_text_and_image():

View File

@@ -3,6 +3,7 @@
import pytest
from pydantic import ValidationError
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse

View File

@@ -1,8 +1,9 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi import HTTPException
import src.server.mcp_utils as mcp_utils

View File

@@ -1,9 +1,9 @@
from unittest.mock import Mock, patch
from src.tools.crawl import crawl_tool
class TestCrawlTool:
@patch("src.tools.crawl.Crawler")
def test_crawl_tool_success(self, mock_crawler_class):
# Arrange

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT
from unittest.mock import Mock, call, patch
from src.tools.decorators import create_logged_tool
@@ -13,7 +14,6 @@ class MockBaseTool:
class TestLoggedToolMixin:
def test_run_calls_log_operation(self):
"""Test that _run calls _log_operation with correct parameters."""
# Create a logged tool instance

View File

@@ -2,13 +2,14 @@
# SPDX-License-Identifier: MIT
import os
import pytest
from unittest.mock import patch
import pytest
from src.tools.python_repl import python_repl_tool
class TestPythonReplTool:
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")

View File

@@ -2,14 +2,15 @@
# SPDX-License-Identifier: MIT
import os
import pytest
from unittest.mock import patch
from src.tools.search import get_web_search_tool
import pytest
from src.config import SearchEngine
from src.tools.search import get_web_search_tool
class TestGetWebSearchTool:
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
def test_get_web_search_tool_tavily(self):
tool = get_web_search_tool(max_search_results=5)

View File

@@ -1,16 +1,17 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from unittest.mock import Mock, patch, AsyncMock, MagicMock
import requests
from src.tools.tavily_search.tavily_search_api_wrapper import (
EnhancedTavilySearchAPIWrapper,
)
class TestEnhancedTavilySearchAPIWrapper:
@pytest.fixture
def wrapper(self):
with patch(

View File

@@ -1,18 +1,19 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import AsyncMock, Mock, patch
import pytest
from unittest.mock import Mock, AsyncMock
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchResultsWithImages,
)
from src.tools.tavily_search.tavily_search_api_wrapper import (
EnhancedTavilySearchAPIWrapper,
)
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchWithImages,
)
class TestTavilySearchResultsWithImages:
class TestTavilySearchWithImages:
@pytest.fixture
def mock_api_wrapper(self):
"""Create a mock API wrapper."""
@@ -21,8 +22,8 @@ class TestTavilySearchResultsWithImages:
@pytest.fixture
def search_tool(self, mock_api_wrapper):
"""Create a TavilySearchResultsWithImages instance with mocked dependencies."""
tool = TavilySearchResultsWithImages(
"""Create a TavilySearchWithImages instance with mocked dependencies."""
tool = TavilySearchWithImages(
max_results=5,
include_answer=True,
include_raw_content=True,
@@ -64,15 +65,13 @@ class TestTavilySearchResultsWithImages:
def test_init_default_values(self):
"""Test initialization with default values."""
tool = TavilySearchResultsWithImages()
tool = TavilySearchWithImages()
assert tool.include_image_descriptions is False
assert isinstance(tool.api_wrapper, EnhancedTavilySearchAPIWrapper)
def test_init_custom_values(self):
"""Test initialization with custom values."""
tool = TavilySearchResultsWithImages(
max_results=10, include_image_descriptions=True
)
tool = TavilySearchWithImages(max_results=10, include_image_descriptions=True)
assert tool.max_results == 10
assert tool.include_image_descriptions is True

View File

@@ -2,13 +2,15 @@
# SPDX-License-Identifier: MIT
from unittest.mock import Mock, patch
from langchain_core.callbacks import (
CallbackManagerForToolRun,
AsyncCallbackManagerForToolRun,
)
import pytest
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from src.rag import Chunk, Document, Resource, Retriever
from src.tools.retriever import RetrieverInput, RetrieverTool, get_retriever_tool
from src.rag import Document, Retriever, Resource, Chunk
def test_retriever_input_model():

View File

@@ -2,11 +2,11 @@
# SPDX-License-Identifier: MIT
import json
from src.utils.json_utils import repair_json_output
class TestRepairJsonOutput:
def test_valid_json_object(self):
"""Test with valid JSON object"""
content = '{"key": "value", "number": 123}'