mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-17 03:34:45 +08:00
feat: 1. replace black with ruff for fomatting and sort import (#489)
2. use tavily from`langchain-tavily` rather than the older one from `langchain-community` Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,21 +1,26 @@
|
||||
from collections import namedtuple
|
||||
import json
|
||||
from collections import namedtuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.graph.nodes import planner_node
|
||||
from src.graph.nodes import human_feedback_node
|
||||
from src.graph.nodes import coordinator_node
|
||||
from src.graph.nodes import reporter_node
|
||||
from src.graph.nodes import _execute_agent_step
|
||||
from src.graph.nodes import _setup_and_execute_agent_step
|
||||
from src.graph.nodes import researcher_node
|
||||
|
||||
from src.graph.nodes import (
|
||||
_execute_agent_step,
|
||||
_setup_and_execute_agent_step,
|
||||
coordinator_node,
|
||||
human_feedback_node,
|
||||
planner_node,
|
||||
reporter_node,
|
||||
researcher_node,
|
||||
)
|
||||
|
||||
# 在这里 mock 掉 get_llm_by_type,避免 ValueError
|
||||
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
|
||||
from langgraph.types import Command
|
||||
from src.graph.nodes import background_investigation_node
|
||||
from src.config import SearchEngine
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from src.config import SearchEngine
|
||||
from src.graph.nodes import background_investigation_node
|
||||
|
||||
|
||||
# Mock data
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from src.prompts.template import get_prompt_template, apply_prompt_template
|
||||
|
||||
from src.prompts.template import apply_prompt_template, get_prompt_template
|
||||
|
||||
|
||||
def test_get_prompt_template_success():
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from src.tools.tts import VolcengineTTS
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
from typing import Annotated
|
||||
|
||||
# Import MessagesState directly from langgraph rather than through our application
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
from src.config.configuration import Configuration
|
||||
|
||||
# Patch sys.path so relative import works
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from src.config.loader import load_yaml_config, process_dict, replace_env_vars
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import src.graph.builder as builder_mod
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
|
||||
from src.llms import llm
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from src.prompt_enhancer.graph.builder import build_graph
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_creation():
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from src.rag.ragflow import RAGFlowProvider, parse_uri
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import json
|
||||
import hashlib
|
||||
import hmac
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider, parse_uri
|
||||
|
||||
|
||||
|
||||
@@ -3,15 +3,16 @@
|
||||
|
||||
import base64
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch, mock_open
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import HTTPException
|
||||
from src.server.app import app, _make_event, _astream_workflow_generator
|
||||
from src.config.report_style import ReportStyle
|
||||
from fastapi.testclient import TestClient
|
||||
from langchain_core.messages import AIMessageChunk, ToolMessage
|
||||
from langgraph.types import Command
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.server.app import _astream_workflow_generator, _make_event, app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -333,7 +334,6 @@ class TestMCPEndpoint:
|
||||
def test_mcp_server_metadata_without_enable_configuration(
|
||||
self, mock_load_tools, client
|
||||
):
|
||||
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
@@ -547,7 +547,6 @@ class TestAstreamWorkflowGenerator:
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.app.graph")
|
||||
async def test_astream_workflow_generator_with_interrupt_feedback(self, mock_graph):
|
||||
|
||||
# Mock the async stream
|
||||
async def mock_astream(*args, **kwargs):
|
||||
# Verify that Command is passed as input when interrupt_feedback is provided
|
||||
@@ -620,7 +619,6 @@ class TestAstreamWorkflowGenerator:
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.app.graph")
|
||||
async def test_astream_workflow_generator_tool_message(self, mock_graph):
|
||||
|
||||
# Mock tool message
|
||||
mock_tool_message = ToolMessage(content="Tool result", tool_call_id="tool_123")
|
||||
mock_tool_message.id = "msg_456"
|
||||
@@ -659,7 +657,6 @@ class TestAstreamWorkflowGenerator:
|
||||
async def test_astream_workflow_generator_ai_message_with_tool_calls(
|
||||
self, mock_graph
|
||||
):
|
||||
|
||||
# Mock AI message with tool calls
|
||||
mock_ai_message = AIMessageChunk(content="Making tool call")
|
||||
mock_ai_message.id = "msg_789"
|
||||
@@ -701,7 +698,6 @@ class TestAstreamWorkflowGenerator:
|
||||
async def test_astream_workflow_generator_ai_message_with_tool_call_chunks(
|
||||
self, mock_graph
|
||||
):
|
||||
|
||||
# Mock AI message with only tool call chunks
|
||||
mock_ai_message = AIMessageChunk(content="Streaming tool call")
|
||||
mock_ai_message.id = "msg_101"
|
||||
@@ -740,7 +736,6 @@ class TestAstreamWorkflowGenerator:
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.app.graph")
|
||||
async def test_astream_workflow_generator_with_finish_reason(self, mock_graph):
|
||||
|
||||
# Mock AI message with finish reason
|
||||
mock_ai_message = AIMessageChunk(content="Complete response")
|
||||
mock_ai_message.id = "msg_finish"
|
||||
@@ -780,7 +775,6 @@ class TestAstreamWorkflowGenerator:
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.app.graph")
|
||||
async def test_astream_workflow_generator_config_passed_correctly(self, mock_graph):
|
||||
|
||||
mock_ai_message = AIMessageChunk(content="Test")
|
||||
mock_ai_message.id = "test_id"
|
||||
mock_ai_message.response_metadata = {}
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
import src.server.mcp_utils as mcp_utils # Assuming mcp_utils is the module to test
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.rag.retriever import Resource
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.server.chat_request import (
|
||||
ContentItem,
|
||||
ChatMessage,
|
||||
ChatRequest,
|
||||
TTSRequest,
|
||||
ContentItem,
|
||||
EnhancePromptRequest,
|
||||
GeneratePodcastRequest,
|
||||
GeneratePPTRequest,
|
||||
GenerateProseRequest,
|
||||
EnhancePromptRequest,
|
||||
TTSRequest,
|
||||
)
|
||||
import src.server.mcp_utils as mcp_utils # Assuming mcp_utils is the module to test
|
||||
|
||||
|
||||
def test_content_item_text_and_image():
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from fastapi import HTTPException
|
||||
|
||||
import src.server.mcp_utils as mcp_utils
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from src.tools.crawl import crawl_tool
|
||||
|
||||
|
||||
class TestCrawlTool:
|
||||
|
||||
@patch("src.tools.crawl.Crawler")
|
||||
def test_crawl_tool_success(self, mock_crawler_class):
|
||||
# Arrange
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
from src.tools.decorators import create_logged_tool
|
||||
|
||||
|
||||
@@ -13,7 +14,6 @@ class MockBaseTool:
|
||||
|
||||
|
||||
class TestLoggedToolMixin:
|
||||
|
||||
def test_run_calls_log_operation(self):
|
||||
"""Test that _run calls _log_operation with correct parameters."""
|
||||
# Create a logged tool instance
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.tools.python_repl import python_repl_tool
|
||||
|
||||
|
||||
class TestPythonReplTool:
|
||||
|
||||
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from src.tools.search import get_web_search_tool
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import SearchEngine
|
||||
from src.tools.search import get_web_search_tool
|
||||
|
||||
|
||||
class TestGetWebSearchTool:
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
||||
def test_get_web_search_tool_tavily(self):
|
||||
tool = get_web_search_tool(max_search_results=5)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
import requests
|
||||
|
||||
from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
EnhancedTavilySearchAPIWrapper,
|
||||
)
|
||||
|
||||
|
||||
class TestEnhancedTavilySearchAPIWrapper:
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self):
|
||||
with patch(
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||
TavilySearchResultsWithImages,
|
||||
)
|
||||
|
||||
from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
EnhancedTavilySearchAPIWrapper,
|
||||
)
|
||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||
TavilySearchWithImages,
|
||||
)
|
||||
|
||||
|
||||
class TestTavilySearchResultsWithImages:
|
||||
|
||||
class TestTavilySearchWithImages:
|
||||
@pytest.fixture
|
||||
def mock_api_wrapper(self):
|
||||
"""Create a mock API wrapper."""
|
||||
@@ -21,8 +22,8 @@ class TestTavilySearchResultsWithImages:
|
||||
|
||||
@pytest.fixture
|
||||
def search_tool(self, mock_api_wrapper):
|
||||
"""Create a TavilySearchResultsWithImages instance with mocked dependencies."""
|
||||
tool = TavilySearchResultsWithImages(
|
||||
"""Create a TavilySearchWithImages instance with mocked dependencies."""
|
||||
tool = TavilySearchWithImages(
|
||||
max_results=5,
|
||||
include_answer=True,
|
||||
include_raw_content=True,
|
||||
@@ -64,15 +65,13 @@ class TestTavilySearchResultsWithImages:
|
||||
|
||||
def test_init_default_values(self):
|
||||
"""Test initialization with default values."""
|
||||
tool = TavilySearchResultsWithImages()
|
||||
tool = TavilySearchWithImages()
|
||||
assert tool.include_image_descriptions is False
|
||||
assert isinstance(tool.api_wrapper, EnhancedTavilySearchAPIWrapper)
|
||||
|
||||
def test_init_custom_values(self):
|
||||
"""Test initialization with custom values."""
|
||||
tool = TavilySearchResultsWithImages(
|
||||
max_results=10, include_image_descriptions=True
|
||||
)
|
||||
tool = TavilySearchWithImages(max_results=10, include_image_descriptions=True)
|
||||
assert tool.max_results == 10
|
||||
assert tool.include_image_descriptions is True
|
||||
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForToolRun,
|
||||
AsyncCallbackManagerForToolRun,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
|
||||
from src.rag import Chunk, Document, Resource, Retriever
|
||||
from src.tools.retriever import RetrieverInput, RetrieverTool, get_retriever_tool
|
||||
from src.rag import Document, Retriever, Resource, Chunk
|
||||
|
||||
|
||||
def test_retriever_input_model():
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
|
||||
from src.utils.json_utils import repair_json_output
|
||||
|
||||
|
||||
class TestRepairJsonOutput:
|
||||
|
||||
def test_valid_json_object(self):
|
||||
"""Test with valid JSON object"""
|
||||
content = '{"key": "value", "number": 123}'
|
||||
|
||||
Reference in New Issue
Block a user