diff --git a/tests/unit/prompt_enhancer/graph/test_enhancer_node.py b/tests/unit/prompt_enhancer/graph/test_enhancer_node.py
index 5086df9..a703801 100644
--- a/tests/unit/prompt_enhancer/graph/test_enhancer_node.py
+++ b/tests/unit/prompt_enhancer/graph/test_enhancer_node.py
@@ -14,7 +14,75 @@ from src.config.report_style import ReportStyle
def mock_llm():
"""Mock LLM that returns a test response."""
llm = MagicMock()
- llm.invoke.return_value = MagicMock(content="Enhanced test prompt")
+ llm.invoke.return_value = MagicMock(
+ content="""Thoughts: LLM thinks a lot
+<enhanced_prompt>
+Enhanced test prompt
+</enhanced_prompt>
+"""
+ )
+ return llm
+
+
+@pytest.fixture
+def mock_llm_xml_with_whitespace():
+ """Mock LLM that returns XML response with extra whitespace."""
+ llm = MagicMock()
+ llm.invoke.return_value = MagicMock(
+ content="""
+Some thoughts here...
+
+<enhanced_prompt>
+
+ Enhanced prompt with whitespace
+
+</enhanced_prompt>
+
+Additional content after XML
+"""
+ )
+ return llm
+
+
+@pytest.fixture
+def mock_llm_xml_multiline():
+ """Mock LLM that returns XML response with multiline content."""
+ llm = MagicMock()
+ llm.invoke.return_value = MagicMock(
+ content="""
+<enhanced_prompt>
+This is a multiline enhanced prompt
+that spans multiple lines
+and includes various formatting.
+
+It should preserve the structure.
+</enhanced_prompt>
+"""
+ )
+ return llm
+
+
+@pytest.fixture
+def mock_llm_no_xml():
+ """Mock LLM that returns response without XML tags."""
+ llm = MagicMock()
+ llm.invoke.return_value = MagicMock(
+ content="Enhanced Prompt: This is an enhanced prompt without XML tags"
+ )
+ return llm
+
+
+@pytest.fixture
+def mock_llm_malformed_xml():
+ """Mock LLM that returns response with malformed XML."""
+ llm = MagicMock()
+ llm.invoke.return_value = MagicMock(
+ content="""
+<enhanced_prompt>
+This XML tag is not properly closed
+
+"""
+ )
return llm
@@ -217,3 +285,241 @@ class TestPromptEnhancerNode:
result = prompt_enhancer_node(state)
assert result == {"output": "Enhanced prompt"}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_xml_with_whitespace_handling(
+ self,
+ mock_get_llm,
+ mock_apply_template,
+ mock_llm_xml_with_whitespace,
+ mock_messages,
+ ):
+ """Test XML extraction with extra whitespace inside tags."""
+ mock_get_llm.return_value = mock_llm_xml_with_whitespace
+ mock_apply_template.return_value = mock_messages
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ assert result == {"output": "Enhanced prompt with whitespace"}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_xml_multiline_content(
+ self, mock_get_llm, mock_apply_template, mock_llm_xml_multiline, mock_messages
+ ):
+ """Test XML extraction with multiline content."""
+ mock_get_llm.return_value = mock_llm_xml_multiline
+ mock_apply_template.return_value = mock_messages
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ expected_output = """This is a multiline enhanced prompt
+that spans multiple lines
+and includes various formatting.
+
+It should preserve the structure."""
+ assert result == {"output": expected_output}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_fallback_to_prefix_removal(
+ self, mock_get_llm, mock_apply_template, mock_llm_no_xml, mock_messages
+ ):
+ """Test fallback to prefix removal when no XML tags are found."""
+ mock_get_llm.return_value = mock_llm_no_xml
+ mock_apply_template.return_value = mock_messages
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ assert result == {"output": "This is an enhanced prompt without XML tags"}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_malformed_xml_fallback(
+ self, mock_get_llm, mock_apply_template, mock_llm_malformed_xml, mock_messages
+ ):
+ """Test handling of malformed XML tags."""
+ mock_get_llm.return_value = mock_llm_malformed_xml
+ mock_apply_template.return_value = mock_messages
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ # Should fall back to using the entire content since XML is malformed
+ expected_content = """<enhanced_prompt>
+This XML tag is not properly closed
+"""
+ assert result == {"output": expected_content}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_case_sensitive_prefix_removal(
+ self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
+ ):
+ """Test that prefix removal is case-sensitive."""
+ mock_get_llm.return_value = mock_llm
+ mock_apply_template.return_value = mock_messages
+
+ # Test case variations that should NOT be removed
+ test_cases = [
+ "ENHANCED PROMPT: This should not be removed",
+ "enhanced prompt: This should not be removed",
+ "Enhanced Prompt This should not be removed", # Missing colon
+ "Enhanced Prompt :: This should not be removed", # Double colon
+ ]
+
+ for response_content in test_cases:
+ mock_llm.invoke.return_value = MagicMock(content=response_content)
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ # Should return the full content since prefix doesn't match exactly
+ assert result == {"output": response_content}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_prefix_with_extra_whitespace(
+ self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
+ ):
+ """Test prefix removal with extra whitespace after colon."""
+ mock_get_llm.return_value = mock_llm
+ mock_apply_template.return_value = mock_messages
+
+ test_cases = [
+ ("Enhanced Prompt: This has extra spaces", "This has extra spaces"),
+ ("Enhanced prompt:\t\tThis has tabs", "This has tabs"),
+ ("Here's the enhanced prompt:\n\nThis has newlines", "This has newlines"),
+ ]
+
+ for response_content, expected_output in test_cases:
+ mock_llm.invoke.return_value = MagicMock(content=response_content)
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ assert result == {"output": expected_output}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_xml_with_special_characters(
+ self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
+ ):
+ """Test XML extraction with special characters and symbols."""
+ mock_get_llm.return_value = mock_llm
+ mock_apply_template.return_value = mock_messages
+
+ special_content = """<enhanced_prompt>
+Enhanced prompt with special chars: @#$%^&*()
+Unicode: 🚀 ✨ 💡
+Quotes: "double" and 'single'
+Backslashes: \\n \\t \\r
+</enhanced_prompt>"""
+
+ mock_llm.invoke.return_value = MagicMock(content=special_content)
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ expected_output = """Enhanced prompt with special chars: @#$%^&*()
+Unicode: 🚀 ✨ 💡
+Quotes: "double" and 'single'
+Backslashes: \\n \\t \\r"""
+ assert result == {"output": expected_output}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_very_long_response(
+ self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
+ ):
+ """Test handling of very long LLM responses."""
+ mock_get_llm.return_value = mock_llm
+ mock_apply_template.return_value = mock_messages
+
+ # Create a very long response
+ long_content = "This is a very long enhanced prompt. " * 100
+ xml_response = f"<enhanced_prompt>\n{long_content}\n</enhanced_prompt>"
+
+ mock_llm.invoke.return_value = MagicMock(content=xml_response)
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ assert result == {"output": long_content.strip()}
+ assert len(result["output"]) > 1000 # Verify it's actually long
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_empty_response_content(
+ self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
+ ):
+ """Test handling of empty response content."""
+ mock_get_llm.return_value = mock_llm
+ mock_apply_template.return_value = mock_messages
+
+ mock_llm.invoke.return_value = MagicMock(content="")
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ assert result == {"output": ""}
+
+ @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
+ @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
+ @patch(
+ "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
+ {"prompt_enhancer": "basic"},
+ )
+ def test_only_whitespace_response(
+ self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
+ ):
+ """Test handling of response with only whitespace."""
+ mock_get_llm.return_value = mock_llm
+ mock_apply_template.return_value = mock_messages
+
+ mock_llm.invoke.return_value = MagicMock(content=" \n\n\t\t ")
+
+ state = PromptEnhancerState(prompt="Test prompt")
+ result = prompt_enhancer_node(state)
+
+ assert result == {"output": ""}