mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
fix: fix unit test cases for prompt enhancer (#431)
This commit is contained in:
@@ -14,7 +14,75 @@ from src.config.report_style import ReportStyle
|
||||
def mock_llm():
|
||||
"""Mock LLM that returns a test response."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(content="Enhanced test prompt")
|
||||
llm.invoke.return_value = MagicMock(
|
||||
content="""Thoughts: LLM thinks a lot
|
||||
<enhanced_prompt>
|
||||
Enhanced test prompt
|
||||
</enhanced_prompt>
|
||||
"""
|
||||
)
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_xml_with_whitespace():
|
||||
"""Mock LLM that returns XML response with extra whitespace."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(
|
||||
content="""
|
||||
Some thoughts here...
|
||||
|
||||
<enhanced_prompt>
|
||||
|
||||
Enhanced prompt with whitespace
|
||||
|
||||
</enhanced_prompt>
|
||||
|
||||
Additional content after XML
|
||||
"""
|
||||
)
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_xml_multiline():
|
||||
"""Mock LLM that returns XML response with multiline content."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(
|
||||
content="""
|
||||
<enhanced_prompt>
|
||||
This is a multiline enhanced prompt
|
||||
that spans multiple lines
|
||||
and includes various formatting.
|
||||
|
||||
It should preserve the structure.
|
||||
</enhanced_prompt>
|
||||
"""
|
||||
)
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_no_xml():
|
||||
"""Mock LLM that returns response without XML tags."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(
|
||||
content="Enhanced Prompt: This is an enhanced prompt without XML tags"
|
||||
)
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_malformed_xml():
|
||||
"""Mock LLM that returns response with malformed XML."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(
|
||||
content="""
|
||||
<enhanced_prompt>
|
||||
This XML tag is not properly closed
|
||||
<enhanced_prompt>
|
||||
"""
|
||||
)
|
||||
return llm
|
||||
|
||||
|
||||
@@ -217,3 +285,241 @@ class TestPromptEnhancerNode:
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": "Enhanced prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_xml_with_whitespace_handling(
|
||||
self,
|
||||
mock_get_llm,
|
||||
mock_apply_template,
|
||||
mock_llm_xml_with_whitespace,
|
||||
mock_messages,
|
||||
):
|
||||
"""Test XML extraction with extra whitespace inside tags."""
|
||||
mock_get_llm.return_value = mock_llm_xml_with_whitespace
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": "Enhanced prompt with whitespace"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_xml_multiline_content(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm_xml_multiline, mock_messages
|
||||
):
|
||||
"""Test XML extraction with multiline content."""
|
||||
mock_get_llm.return_value = mock_llm_xml_multiline
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
expected_output = """This is a multiline enhanced prompt
|
||||
that spans multiple lines
|
||||
and includes various formatting.
|
||||
|
||||
It should preserve the structure."""
|
||||
assert result == {"output": expected_output}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_fallback_to_prefix_removal(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm_no_xml, mock_messages
|
||||
):
|
||||
"""Test fallback to prefix removal when no XML tags are found."""
|
||||
mock_get_llm.return_value = mock_llm_no_xml
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": "This is an enhanced prompt without XML tags"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_malformed_xml_fallback(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm_malformed_xml, mock_messages
|
||||
):
|
||||
"""Test handling of malformed XML tags."""
|
||||
mock_get_llm.return_value = mock_llm_malformed_xml
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Should fall back to using the entire content since XML is malformed
|
||||
expected_content = """<enhanced_prompt>
|
||||
This XML tag is not properly closed
|
||||
<enhanced_prompt>"""
|
||||
assert result == {"output": expected_content}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_case_sensitive_prefix_removal(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test that prefix removal is case-sensitive."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
# Test case variations that should NOT be removed
|
||||
test_cases = [
|
||||
"ENHANCED PROMPT: This should not be removed",
|
||||
"enhanced prompt: This should not be removed",
|
||||
"Enhanced Prompt This should not be removed", # Missing colon
|
||||
"Enhanced Prompt :: This should not be removed", # Double colon
|
||||
]
|
||||
|
||||
for response_content in test_cases:
|
||||
mock_llm.invoke.return_value = MagicMock(content=response_content)
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Should return the full content since prefix doesn't match exactly
|
||||
assert result == {"output": response_content}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_prefix_with_extra_whitespace(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test prefix removal with extra whitespace after colon."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
test_cases = [
|
||||
("Enhanced Prompt: This has extra spaces", "This has extra spaces"),
|
||||
("Enhanced prompt:\t\tThis has tabs", "This has tabs"),
|
||||
("Here's the enhanced prompt:\n\nThis has newlines", "This has newlines"),
|
||||
]
|
||||
|
||||
for response_content, expected_output in test_cases:
|
||||
mock_llm.invoke.return_value = MagicMock(content=response_content)
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": expected_output}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_xml_with_special_characters(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test XML extraction with special characters and symbols."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
special_content = """<enhanced_prompt>
|
||||
Enhanced prompt with special chars: @#$%^&*()
|
||||
Unicode: 🚀 ✨ 💡
|
||||
Quotes: "double" and 'single'
|
||||
Backslashes: \\n \\t \\r
|
||||
</enhanced_prompt>"""
|
||||
|
||||
mock_llm.invoke.return_value = MagicMock(content=special_content)
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
expected_output = """Enhanced prompt with special chars: @#$%^&*()
|
||||
Unicode: 🚀 ✨ 💡
|
||||
Quotes: "double" and 'single'
|
||||
Backslashes: \\n \\t \\r"""
|
||||
assert result == {"output": expected_output}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_very_long_response(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test handling of very long LLM responses."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
# Create a very long response
|
||||
long_content = "This is a very long enhanced prompt. " * 100
|
||||
xml_response = f"<enhanced_prompt>\n{long_content}\n</enhanced_prompt>"
|
||||
|
||||
mock_llm.invoke.return_value = MagicMock(content=xml_response)
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": long_content.strip()}
|
||||
assert len(result["output"]) > 1000 # Verify it's actually long
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_empty_response_content(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test handling of empty response content."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
mock_llm.invoke.return_value = MagicMock(content="")
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": ""}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_only_whitespace_response(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test handling of response with only whitespace."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
mock_llm.invoke.return_value = MagicMock(content=" \n\n\t\t ")
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": ""}
|
||||
|
||||
Reference in New Issue
Block a user