mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 04:14:46 +08:00
feat: implement enhance prompt (#294)
* feat: implement enhance prompt * add unit test * fix prompt * fix: fix eslint and compiling issues * feat: add border-beam animation * fix: fix importing issues --------- Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
2
tests/unit/prompt_enhancer/__init__.py
Normal file
2
tests/unit/prompt_enhancer/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
2
tests/unit/prompt_enhancer/graph/__init__.py
Normal file
2
tests/unit/prompt_enhancer/graph/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
156
tests/unit/prompt_enhancer/graph/test_builder.py
Normal file
156
tests/unit/prompt_enhancer/graph/test_builder.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from src.prompt_enhancer.graph.builder import build_graph
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
class TestBuildGraph:
|
||||
"""Test cases for build_graph function."""
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_structure(self, mock_state_graph):
|
||||
"""Test that build_graph creates the correct graph structure."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
result = build_graph()
|
||||
|
||||
# Verify StateGraph was created with correct state type
|
||||
mock_state_graph.assert_called_once_with(PromptEnhancerState)
|
||||
|
||||
# Verify entry point was set
|
||||
mock_builder.set_entry_point.assert_called_once_with("enhancer")
|
||||
|
||||
# Verify finish point was set
|
||||
mock_builder.set_finish_point.assert_called_once_with("enhancer")
|
||||
|
||||
# Verify graph was compiled
|
||||
mock_builder.compile.assert_called_once()
|
||||
|
||||
# Verify return value
|
||||
assert result == mock_compiled_graph
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
@patch("src.prompt_enhancer.graph.builder.prompt_enhancer_node")
|
||||
def test_build_graph_node_function(self, mock_enhancer_node, mock_state_graph):
|
||||
"""Test that the correct node function is added to the graph."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
result = build_graph()
|
||||
|
||||
# Verify the correct node function was added
|
||||
mock_builder.add_node.assert_called_once_with("enhancer", mock_enhancer_node)
|
||||
|
||||
def test_build_graph_returns_compiled_graph(self):
|
||||
"""Test that build_graph returns a compiled graph object."""
|
||||
with patch("src.prompt_enhancer.graph.builder.StateGraph") as mock_state_graph:
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
result = build_graph()
|
||||
|
||||
assert result is mock_compiled_graph
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_call_sequence(self, mock_state_graph):
|
||||
"""Test that build_graph calls methods in the correct sequence."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
# Track call order
|
||||
call_order = []
|
||||
|
||||
def track_add_node(*args, **kwargs):
|
||||
call_order.append("add_node")
|
||||
|
||||
def track_set_entry_point(*args, **kwargs):
|
||||
call_order.append("set_entry_point")
|
||||
|
||||
def track_set_finish_point(*args, **kwargs):
|
||||
call_order.append("set_finish_point")
|
||||
|
||||
def track_compile(*args, **kwargs):
|
||||
call_order.append("compile")
|
||||
return mock_compiled_graph
|
||||
|
||||
mock_builder.add_node.side_effect = track_add_node
|
||||
mock_builder.set_entry_point.side_effect = track_set_entry_point
|
||||
mock_builder.set_finish_point.side_effect = track_set_finish_point
|
||||
mock_builder.compile.side_effect = track_compile
|
||||
|
||||
build_graph()
|
||||
|
||||
# Verify the correct call sequence
|
||||
expected_order = ["add_node", "set_entry_point", "set_finish_point", "compile"]
|
||||
assert call_order == expected_order
|
||||
|
||||
def test_build_graph_integration(self):
|
||||
"""Integration test to verify the graph can be built without mocking."""
|
||||
# This test verifies that all imports and dependencies are correct
|
||||
try:
|
||||
graph = build_graph()
|
||||
assert graph is not None
|
||||
# The graph should be a compiled LangGraph object
|
||||
assert hasattr(graph, "invoke") or hasattr(graph, "stream")
|
||||
except ImportError as e:
|
||||
pytest.skip(f"Skipping integration test due to missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
# If there are configuration issues (like missing LLM config),
|
||||
# we still consider the test successful if the graph structure is built
|
||||
if "LLM" in str(e) or "configuration" in str(e).lower():
|
||||
pytest.skip(
|
||||
f"Skipping integration test due to configuration issues: {e}"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_single_node_workflow(self, mock_state_graph):
|
||||
"""Test that the graph is configured as a single-node workflow."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
build_graph()
|
||||
|
||||
# Verify only one node is added
|
||||
assert mock_builder.add_node.call_count == 1
|
||||
|
||||
# Verify entry and finish points are the same node
|
||||
mock_builder.set_entry_point.assert_called_once_with("enhancer")
|
||||
mock_builder.set_finish_point.assert_called_once_with("enhancer")
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_state_type(self, mock_state_graph):
|
||||
"""Test that the graph is initialized with the correct state type."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
build_graph()
|
||||
|
||||
# Verify StateGraph was initialized with PromptEnhancerState
|
||||
args, kwargs = mock_state_graph.call_args
|
||||
assert args[0] == PromptEnhancerState
|
||||
219
tests/unit/prompt_enhancer/graph/test_enhancer_node.py
Normal file
219
tests/unit/prompt_enhancer/graph/test_enhancer_node.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
"""Mock LLM that returns a test response."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(content="Enhanced test prompt")
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_messages():
|
||||
"""Mock messages returned by apply_prompt_template."""
|
||||
return [
|
||||
SystemMessage(content="System prompt template"),
|
||||
HumanMessage(content="Test human message"),
|
||||
]
|
||||
|
||||
|
||||
class TestPromptEnhancerNode:
|
||||
"""Test cases for prompt_enhancer_node function."""
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_basic_prompt_enhancement(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test basic prompt enhancement without context or report style."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(prompt="Write about AI")
|
||||
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Verify LLM was called
|
||||
mock_get_llm.assert_called_once_with("basic")
|
||||
mock_llm.invoke.assert_called_once_with(mock_messages)
|
||||
|
||||
# Verify apply_prompt_template was called correctly
|
||||
mock_apply_template.assert_called_once()
|
||||
call_args = mock_apply_template.call_args
|
||||
assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
|
||||
assert "messages" in call_args[0][1]
|
||||
assert "report_style" in call_args[0][1]
|
||||
|
||||
# Verify result
|
||||
assert result == {"output": "Enhanced test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_prompt_enhancement_with_report_style(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test prompt enhancement with report style."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(
|
||||
prompt="Write about AI", report_style=ReportStyle.ACADEMIC
|
||||
)
|
||||
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Verify apply_prompt_template was called with report_style
|
||||
mock_apply_template.assert_called_once()
|
||||
call_args = mock_apply_template.call_args
|
||||
assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
|
||||
assert call_args[0][1]["report_style"] == ReportStyle.ACADEMIC
|
||||
|
||||
# Verify result
|
||||
assert result == {"output": "Enhanced test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_prompt_enhancement_with_context(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test prompt enhancement with additional context."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(
|
||||
prompt="Write about AI", context="Focus on machine learning applications"
|
||||
)
|
||||
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Verify apply_prompt_template was called
|
||||
mock_apply_template.assert_called_once()
|
||||
call_args = mock_apply_template.call_args
|
||||
|
||||
# Check that the context was included in the human message
|
||||
messages_arg = call_args[0][1]["messages"]
|
||||
assert len(messages_arg) == 1
|
||||
human_message = messages_arg[0]
|
||||
assert isinstance(human_message, HumanMessage)
|
||||
assert "Focus on machine learning applications" in human_message.content
|
||||
|
||||
assert result == {"output": "Enhanced test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_error_handling(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test error handling when LLM call fails."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
# Mock LLM to raise an exception
|
||||
mock_llm.invoke.side_effect = Exception("LLM error")
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Should return original prompt on error
|
||||
assert result == {"output": "Test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_template_error_handling(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test error handling when template application fails."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
# Mock apply_prompt_template to raise an exception
|
||||
mock_apply_template.side_effect = Exception("Template error")
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Should return original prompt on error
|
||||
assert result == {"output": "Test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_prefix_removal(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test that common prefixes are removed from LLM response."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
# Test different prefixes that should be removed
|
||||
test_cases = [
|
||||
"Enhanced Prompt: This is the enhanced prompt",
|
||||
"Enhanced prompt: This is the enhanced prompt",
|
||||
"Here's the enhanced prompt: This is the enhanced prompt",
|
||||
"Here is the enhanced prompt: This is the enhanced prompt",
|
||||
"**Enhanced Prompt**: This is the enhanced prompt",
|
||||
"**Enhanced prompt**: This is the enhanced prompt",
|
||||
]
|
||||
|
||||
for response_with_prefix in test_cases:
|
||||
mock_llm.invoke.return_value = MagicMock(content=response_with_prefix)
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": "This is the enhanced prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_whitespace_handling(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test that whitespace is properly stripped from LLM response."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
# Mock LLM response with extra whitespace
|
||||
mock_llm.invoke.return_value = MagicMock(
|
||||
content=" \n\n Enhanced prompt \n\n "
|
||||
)
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": "Enhanced prompt"}
|
||||
108
tests/unit/prompt_enhancer/graph/test_state.py
Normal file
108
tests/unit/prompt_enhancer/graph/test_state.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_creation():
|
||||
"""Test that PromptEnhancerState can be created with required fields."""
|
||||
state = PromptEnhancerState(
|
||||
prompt="Test prompt", context=None, report_style=None, output=None
|
||||
)
|
||||
|
||||
assert state["prompt"] == "Test prompt"
|
||||
assert state["context"] is None
|
||||
assert state["report_style"] is None
|
||||
assert state["output"] is None
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_with_all_fields():
|
||||
"""Test PromptEnhancerState with all fields populated."""
|
||||
state = PromptEnhancerState(
|
||||
prompt="Write about AI",
|
||||
context="Additional context about AI research",
|
||||
report_style=ReportStyle.ACADEMIC,
|
||||
output="Enhanced prompt about AI research",
|
||||
)
|
||||
|
||||
assert state["prompt"] == "Write about AI"
|
||||
assert state["context"] == "Additional context about AI research"
|
||||
assert state["report_style"] == ReportStyle.ACADEMIC
|
||||
assert state["output"] == "Enhanced prompt about AI research"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_minimal():
|
||||
"""Test PromptEnhancerState with only required prompt field."""
|
||||
state = PromptEnhancerState(prompt="Minimal prompt")
|
||||
|
||||
assert state["prompt"] == "Minimal prompt"
|
||||
# Optional fields should not be present if not specified
|
||||
assert "context" not in state
|
||||
assert "report_style" not in state
|
||||
assert "output" not in state
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_with_different_report_styles():
|
||||
"""Test PromptEnhancerState with different ReportStyle values."""
|
||||
styles = [
|
||||
ReportStyle.ACADEMIC,
|
||||
ReportStyle.POPULAR_SCIENCE,
|
||||
ReportStyle.NEWS,
|
||||
ReportStyle.SOCIAL_MEDIA,
|
||||
]
|
||||
|
||||
for style in styles:
|
||||
state = PromptEnhancerState(prompt="Test prompt", report_style=style)
|
||||
assert state["report_style"] == style
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_update():
|
||||
"""Test updating PromptEnhancerState fields."""
|
||||
state = PromptEnhancerState(prompt="Original prompt")
|
||||
|
||||
# Update with new fields
|
||||
state.update(
|
||||
{
|
||||
"context": "New context",
|
||||
"report_style": ReportStyle.NEWS,
|
||||
"output": "Enhanced output",
|
||||
}
|
||||
)
|
||||
|
||||
assert state["prompt"] == "Original prompt"
|
||||
assert state["context"] == "New context"
|
||||
assert state["report_style"] == ReportStyle.NEWS
|
||||
assert state["output"] == "Enhanced output"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_get_method():
|
||||
"""Test using get() method on PromptEnhancerState."""
|
||||
state = PromptEnhancerState(prompt="Test prompt", report_style=ReportStyle.ACADEMIC)
|
||||
|
||||
# Test get with existing keys
|
||||
assert state.get("prompt") == "Test prompt"
|
||||
assert state.get("report_style") == ReportStyle.ACADEMIC
|
||||
|
||||
# Test get with non-existing keys
|
||||
assert state.get("context") is None
|
||||
assert state.get("output") is None
|
||||
assert state.get("nonexistent", "default") == "default"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_type_annotations():
|
||||
"""Test that the state accepts correct types."""
|
||||
# This test ensures the TypedDict structure is working correctly
|
||||
state = PromptEnhancerState(
|
||||
prompt="Test prompt",
|
||||
context="Test context",
|
||||
report_style=ReportStyle.POPULAR_SCIENCE,
|
||||
output="Test output",
|
||||
)
|
||||
|
||||
# Verify types
|
||||
assert isinstance(state["prompt"], str)
|
||||
assert isinstance(state["context"], str)
|
||||
assert isinstance(state["report_style"], ReportStyle)
|
||||
assert isinstance(state["output"], str)
|
||||
Reference in New Issue
Block a user