mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-18 03:54:46 +08:00
fix: fix unit test cases for prompt enhancer (#431)
This commit is contained in:
@@ -14,7 +14,75 @@ from src.config.report_style import ReportStyle
|
|||||||
def mock_llm():
|
def mock_llm():
|
||||||
"""Mock LLM that returns a test response."""
|
"""Mock LLM that returns a test response."""
|
||||||
llm = MagicMock()
|
llm = MagicMock()
|
||||||
llm.invoke.return_value = MagicMock(content="Enhanced test prompt")
|
llm.invoke.return_value = MagicMock(
|
||||||
|
content="""Thoughts: LLM thinks a lot
|
||||||
|
<enhanced_prompt>
|
||||||
|
Enhanced test prompt
|
||||||
|
</enhanced_prompt>
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm_xml_with_whitespace():
|
||||||
|
"""Mock LLM that returns XML response with extra whitespace."""
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.invoke.return_value = MagicMock(
|
||||||
|
content="""
|
||||||
|
Some thoughts here...
|
||||||
|
|
||||||
|
<enhanced_prompt>
|
||||||
|
|
||||||
|
Enhanced prompt with whitespace
|
||||||
|
|
||||||
|
</enhanced_prompt>
|
||||||
|
|
||||||
|
Additional content after XML
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm_xml_multiline():
|
||||||
|
"""Mock LLM that returns XML response with multiline content."""
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.invoke.return_value = MagicMock(
|
||||||
|
content="""
|
||||||
|
<enhanced_prompt>
|
||||||
|
This is a multiline enhanced prompt
|
||||||
|
that spans multiple lines
|
||||||
|
and includes various formatting.
|
||||||
|
|
||||||
|
It should preserve the structure.
|
||||||
|
</enhanced_prompt>
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm_no_xml():
|
||||||
|
"""Mock LLM that returns response without XML tags."""
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.invoke.return_value = MagicMock(
|
||||||
|
content="Enhanced Prompt: This is an enhanced prompt without XML tags"
|
||||||
|
)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm_malformed_xml():
|
||||||
|
"""Mock LLM that returns response with malformed XML."""
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.invoke.return_value = MagicMock(
|
||||||
|
content="""
|
||||||
|
<enhanced_prompt>
|
||||||
|
This XML tag is not properly closed
|
||||||
|
<enhanced_prompt>
|
||||||
|
"""
|
||||||
|
)
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
||||||
@@ -217,3 +285,241 @@ class TestPromptEnhancerNode:
|
|||||||
result = prompt_enhancer_node(state)
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
assert result == {"output": "Enhanced prompt"}
|
assert result == {"output": "Enhanced prompt"}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_xml_with_whitespace_handling(
|
||||||
|
self,
|
||||||
|
mock_get_llm,
|
||||||
|
mock_apply_template,
|
||||||
|
mock_llm_xml_with_whitespace,
|
||||||
|
mock_messages,
|
||||||
|
):
|
||||||
|
"""Test XML extraction with extra whitespace inside tags."""
|
||||||
|
mock_get_llm.return_value = mock_llm_xml_with_whitespace
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
assert result == {"output": "Enhanced prompt with whitespace"}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_xml_multiline_content(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm_xml_multiline, mock_messages
|
||||||
|
):
|
||||||
|
"""Test XML extraction with multiline content."""
|
||||||
|
mock_get_llm.return_value = mock_llm_xml_multiline
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
expected_output = """This is a multiline enhanced prompt
|
||||||
|
that spans multiple lines
|
||||||
|
and includes various formatting.
|
||||||
|
|
||||||
|
It should preserve the structure."""
|
||||||
|
assert result == {"output": expected_output}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_fallback_to_prefix_removal(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm_no_xml, mock_messages
|
||||||
|
):
|
||||||
|
"""Test fallback to prefix removal when no XML tags are found."""
|
||||||
|
mock_get_llm.return_value = mock_llm_no_xml
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
assert result == {"output": "This is an enhanced prompt without XML tags"}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_malformed_xml_fallback(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm_malformed_xml, mock_messages
|
||||||
|
):
|
||||||
|
"""Test handling of malformed XML tags."""
|
||||||
|
mock_get_llm.return_value = mock_llm_malformed_xml
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
# Should fall back to using the entire content since XML is malformed
|
||||||
|
expected_content = """<enhanced_prompt>
|
||||||
|
This XML tag is not properly closed
|
||||||
|
<enhanced_prompt>"""
|
||||||
|
assert result == {"output": expected_content}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_case_sensitive_prefix_removal(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||||
|
):
|
||||||
|
"""Test that prefix removal is case-sensitive."""
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
# Test case variations that should NOT be removed
|
||||||
|
test_cases = [
|
||||||
|
"ENHANCED PROMPT: This should not be removed",
|
||||||
|
"enhanced prompt: This should not be removed",
|
||||||
|
"Enhanced Prompt This should not be removed", # Missing colon
|
||||||
|
"Enhanced Prompt :: This should not be removed", # Double colon
|
||||||
|
]
|
||||||
|
|
||||||
|
for response_content in test_cases:
|
||||||
|
mock_llm.invoke.return_value = MagicMock(content=response_content)
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
# Should return the full content since prefix doesn't match exactly
|
||||||
|
assert result == {"output": response_content}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_prefix_with_extra_whitespace(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||||
|
):
|
||||||
|
"""Test prefix removal with extra whitespace after colon."""
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
("Enhanced Prompt: This has extra spaces", "This has extra spaces"),
|
||||||
|
("Enhanced prompt:\t\tThis has tabs", "This has tabs"),
|
||||||
|
("Here's the enhanced prompt:\n\nThis has newlines", "This has newlines"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for response_content, expected_output in test_cases:
|
||||||
|
mock_llm.invoke.return_value = MagicMock(content=response_content)
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
assert result == {"output": expected_output}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_xml_with_special_characters(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||||
|
):
|
||||||
|
"""Test XML extraction with special characters and symbols."""
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
special_content = """<enhanced_prompt>
|
||||||
|
Enhanced prompt with special chars: @#$%^&*()
|
||||||
|
Unicode: 🚀 ✨ 💡
|
||||||
|
Quotes: "double" and 'single'
|
||||||
|
Backslashes: \\n \\t \\r
|
||||||
|
</enhanced_prompt>"""
|
||||||
|
|
||||||
|
mock_llm.invoke.return_value = MagicMock(content=special_content)
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
expected_output = """Enhanced prompt with special chars: @#$%^&*()
|
||||||
|
Unicode: 🚀 ✨ 💡
|
||||||
|
Quotes: "double" and 'single'
|
||||||
|
Backslashes: \\n \\t \\r"""
|
||||||
|
assert result == {"output": expected_output}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_very_long_response(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||||
|
):
|
||||||
|
"""Test handling of very long LLM responses."""
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
# Create a very long response
|
||||||
|
long_content = "This is a very long enhanced prompt. " * 100
|
||||||
|
xml_response = f"<enhanced_prompt>\n{long_content}\n</enhanced_prompt>"
|
||||||
|
|
||||||
|
mock_llm.invoke.return_value = MagicMock(content=xml_response)
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
assert result == {"output": long_content.strip()}
|
||||||
|
assert len(result["output"]) > 1000 # Verify it's actually long
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_empty_response_content(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||||
|
):
|
||||||
|
"""Test handling of empty response content."""
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
mock_llm.invoke.return_value = MagicMock(content="")
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
assert result == {"output": ""}
|
||||||
|
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||||
|
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||||
|
@patch(
|
||||||
|
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||||
|
{"prompt_enhancer": "basic"},
|
||||||
|
)
|
||||||
|
def test_only_whitespace_response(
|
||||||
|
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||||
|
):
|
||||||
|
"""Test handling of response with only whitespace."""
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
|
||||||
|
mock_llm.invoke.return_value = MagicMock(content=" \n\n\t\t ")
|
||||||
|
|
||||||
|
state = PromptEnhancerState(prompt="Test prompt")
|
||||||
|
result = prompt_enhancer_node(state)
|
||||||
|
|
||||||
|
assert result == {"output": ""}
|
||||||
|
|||||||
Reference in New Issue
Block a user