mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-05 23:10:22 +08:00
* feat: implement enhance prompt * add unit test * fix prompt * fix: fix eslint and compiling issues * feat: add border-beam animation * fix: fix importing issues --------- Co-authored-by: Henry Li <henry1943@163.com>
220 lines
8.2 KiB
Python
220 lines
8.2 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
from langchain.schema import HumanMessage, SystemMessage
|
|
|
|
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
|
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
|
from src.config.report_style import ReportStyle
|
|
|
|
|
|
@pytest.fixture
def mock_llm():
    """Provide a stand-in chat model whose ``invoke`` returns a canned reply.

    The canned reply exposes a ``content`` attribute set to
    ``"Enhanced test prompt"``. Individual tests override
    ``invoke.return_value`` or ``invoke.side_effect`` when they need a
    different reply or a failure.
    """
    canned_reply = MagicMock(content="Enhanced test prompt")
    fake_llm = MagicMock()
    fake_llm.invoke.return_value = canned_reply
    return fake_llm
|
|
|
|
|
|
@pytest.fixture
def mock_messages():
    """Supply the message list that the patched ``apply_prompt_template`` returns.

    The list holds one system message followed by one human message.
    """
    system_message = SystemMessage(content="System prompt template")
    human_message = HumanMessage(content="Test human message")
    return [system_message, human_message]
|
|
|
|
|
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
    "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
    {"prompt_enhancer": "basic"},
)
class TestPromptEnhancerNode:
    """Test cases for prompt_enhancer_node function.

    The ``patch`` decorators on the class apply to every ``test*`` method,
    exactly as if each method carried them itself. ``AGENT_LLM_MAP`` is
    replaced with a concrete dict, so it injects no argument. The other two
    patches inject, bottom-up, ``mock_get_llm`` and then ``mock_apply_template``
    as the first positional arguments after ``self``. Requested pytest
    fixtures (``mock_llm``, ``mock_messages``) and parametrized values follow.
    """

    def test_basic_prompt_enhancement(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test basic prompt enhancement without context or report style."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(prompt="Write about AI")
        result = prompt_enhancer_node(state)

        # Verify the LLM type comes from the patched AGENT_LLM_MAP and the
        # model receives exactly the templated messages.
        mock_get_llm.assert_called_once_with("basic")
        mock_llm.invoke.assert_called_once_with(mock_messages)

        # Verify apply_prompt_template was called correctly
        mock_apply_template.assert_called_once()
        template_args = mock_apply_template.call_args.args
        assert template_args[0] == "prompt_enhancer/prompt_enhancer"
        assert "messages" in template_args[1]
        assert "report_style" in template_args[1]

        # Verify result
        assert result == {"output": "Enhanced test prompt"}

    def test_prompt_enhancement_with_report_style(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test prompt enhancement with report style."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(
            prompt="Write about AI", report_style=ReportStyle.ACADEMIC
        )
        result = prompt_enhancer_node(state)

        # Verify apply_prompt_template was called with report_style
        mock_apply_template.assert_called_once()
        template_args = mock_apply_template.call_args.args
        assert template_args[0] == "prompt_enhancer/prompt_enhancer"
        assert template_args[1]["report_style"] == ReportStyle.ACADEMIC

        # Verify result
        assert result == {"output": "Enhanced test prompt"}

    def test_prompt_enhancement_with_context(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test prompt enhancement with additional context."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(
            prompt="Write about AI", context="Focus on machine learning applications"
        )
        result = prompt_enhancer_node(state)

        # Verify apply_prompt_template was called
        mock_apply_template.assert_called_once()
        template_args = mock_apply_template.call_args.args

        # Check that the context was folded into the single human message
        messages_arg = template_args[1]["messages"]
        assert len(messages_arg) == 1
        human_message = messages_arg[0]
        assert isinstance(human_message, HumanMessage)
        assert "Focus on machine learning applications" in human_message.content

        assert result == {"output": "Enhanced test prompt"}

    def test_error_handling(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test error handling when LLM call fails."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Mock LLM to raise an exception
        mock_llm.invoke.side_effect = Exception("LLM error")

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        # Make sure the failure really came from the LLM call, not from some
        # earlier step that would make this test pass for the wrong reason.
        mock_llm.invoke.assert_called_once_with(mock_messages)

        # Should return original prompt on error
        assert result == {"output": "Test prompt"}

    def test_template_error_handling(
        self, mock_get_llm, mock_apply_template, mock_llm
    ):
        """Test error handling when template application fails."""
        mock_get_llm.return_value = mock_llm

        # Mock apply_prompt_template to raise an exception
        mock_apply_template.side_effect = Exception("Template error")

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        # Should return original prompt on error
        assert result == {"output": "Test prompt"}

    @pytest.mark.parametrize(
        "response_with_prefix",
        [
            "Enhanced Prompt: This is the enhanced prompt",
            "Enhanced prompt: This is the enhanced prompt",
            "Here's the enhanced prompt: This is the enhanced prompt",
            "Here is the enhanced prompt: This is the enhanced prompt",
            "**Enhanced Prompt**: This is the enhanced prompt",
            "**Enhanced prompt**: This is the enhanced prompt",
        ],
    )
    def test_prefix_removal(
        self,
        mock_get_llm,
        mock_apply_template,
        mock_llm,
        mock_messages,
        response_with_prefix,
    ):
        """Test that common prefixes are removed from LLM response.

        Parametrized so each prefix variant is reported on its own and runs
        against fresh mocks.
        """
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages
        mock_llm.invoke.return_value = MagicMock(content=response_with_prefix)

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": "This is the enhanced prompt"}

    def test_whitespace_handling(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test that whitespace is properly stripped from LLM response."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Mock LLM response with extra whitespace
        mock_llm.invoke.return_value = MagicMock(
            content="  \n\n  Enhanced prompt  \n\n  "
        )

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": "Enhanced prompt"}
|