diff --git a/tests/unit/prompt_enhancer/graph/test_enhancer_node.py b/tests/unit/prompt_enhancer/graph/test_enhancer_node.py index 5086df9..a703801 100644 --- a/tests/unit/prompt_enhancer/graph/test_enhancer_node.py +++ b/tests/unit/prompt_enhancer/graph/test_enhancer_node.py @@ -14,7 +14,75 @@ from src.config.report_style import ReportStyle def mock_llm(): """Mock LLM that returns a test response.""" llm = MagicMock() - llm.invoke.return_value = MagicMock(content="Enhanced test prompt") + llm.invoke.return_value = MagicMock( + content="""Thoughts: LLM thinks a lot + +Enhanced test prompt + +""" + ) + return llm + + +@pytest.fixture +def mock_llm_xml_with_whitespace(): + """Mock LLM that returns XML response with extra whitespace.""" + llm = MagicMock() + llm.invoke.return_value = MagicMock( + content=""" +Some thoughts here... + + + + Enhanced prompt with whitespace + + + +Additional content after XML +""" + ) + return llm + + +@pytest.fixture +def mock_llm_xml_multiline(): + """Mock LLM that returns XML response with multiline content.""" + llm = MagicMock() + llm.invoke.return_value = MagicMock( + content=""" + +This is a multiline enhanced prompt +that spans multiple lines +and includes various formatting. + +It should preserve the structure. + +""" + ) + return llm + + +@pytest.fixture +def mock_llm_no_xml(): + """Mock LLM that returns response without XML tags.""" + llm = MagicMock() + llm.invoke.return_value = MagicMock( + content="Enhanced Prompt: This is an enhanced prompt without XML tags" + ) + return llm + + +@pytest.fixture +def mock_llm_malformed_xml(): + """Mock LLM that returns response with malformed XML.""" + llm = MagicMock() + llm.invoke.return_value = MagicMock( + content=""" + +This XML tag is not properly closed + +""" + ) return llm @@ -217,3 +285,241 @@ class TestPromptEnhancerNode: result = prompt_enhancer_node(state) assert result == {"output": "Enhanced prompt"} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_xml_with_whitespace_handling( + self, + mock_get_llm, + mock_apply_template, + mock_llm_xml_with_whitespace, + mock_messages, + ): + """Test XML extraction with extra whitespace inside tags.""" + mock_get_llm.return_value = mock_llm_xml_with_whitespace + mock_apply_template.return_value = mock_messages + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + assert result == {"output": "Enhanced prompt with whitespace"} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_xml_multiline_content( + self, mock_get_llm, mock_apply_template, mock_llm_xml_multiline, mock_messages + ): + """Test XML extraction with multiline content.""" + mock_get_llm.return_value = mock_llm_xml_multiline + mock_apply_template.return_value = mock_messages + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + expected_output = """This is a multiline enhanced prompt +that spans multiple lines +and includes various formatting. + +It should preserve the structure.""" + assert result == {"output": expected_output} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_fallback_to_prefix_removal( + self, mock_get_llm, mock_apply_template, mock_llm_no_xml, mock_messages + ): + """Test fallback to prefix removal when no XML tags are found.""" + mock_get_llm.return_value = mock_llm_no_xml + mock_apply_template.return_value = mock_messages + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + assert result == {"output": "This is an enhanced prompt without XML tags"} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_malformed_xml_fallback( + self, mock_get_llm, mock_apply_template, mock_llm_malformed_xml, mock_messages + ): + """Test handling of malformed XML tags.""" + mock_get_llm.return_value = mock_llm_malformed_xml + mock_apply_template.return_value = mock_messages + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + # Should fall back to using the entire content since XML is malformed + expected_content = """ +This XML tag is not properly closed +""" + assert result == {"output": expected_content} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_case_sensitive_prefix_removal( + self, mock_get_llm, mock_apply_template, mock_llm, mock_messages + ): + """Test that prefix removal is case-sensitive.""" + mock_get_llm.return_value = mock_llm + mock_apply_template.return_value = mock_messages + + # Test case variations that should NOT be removed + test_cases = [ + "ENHANCED PROMPT: This should not be removed", + "enhanced prompt: This should not be removed", + "Enhanced Prompt This should not be removed", # Missing colon + "Enhanced Prompt :: This should not be removed", # Double colon + ] + + for response_content in test_cases: + mock_llm.invoke.return_value = MagicMock(content=response_content) + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + # Should return the full content since prefix doesn't match exactly + assert result == {"output": response_content} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_prefix_with_extra_whitespace( + self, mock_get_llm, mock_apply_template, mock_llm, mock_messages + ): + """Test prefix removal with extra whitespace after colon.""" + mock_get_llm.return_value = mock_llm + mock_apply_template.return_value = mock_messages + + test_cases = [ + ("Enhanced Prompt: This has extra spaces", "This has extra spaces"), + ("Enhanced prompt:\t\tThis has tabs", "This has tabs"), + ("Here's the enhanced prompt:\n\nThis has newlines", "This has newlines"), + ] + + for response_content, expected_output in test_cases: + mock_llm.invoke.return_value = MagicMock(content=response_content) + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + assert result == {"output": expected_output} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_xml_with_special_characters( + self, mock_get_llm, mock_apply_template, mock_llm, mock_messages + ): + """Test XML extraction with special characters and symbols.""" + mock_get_llm.return_value = mock_llm + mock_apply_template.return_value = mock_messages + + special_content = """ +Enhanced prompt with special chars: @#$%^&*() +Unicode: 🚀 ✨ 💡 +Quotes: "double" and 'single' +Backslashes: \\n \\t \\r +""" + + mock_llm.invoke.return_value = MagicMock(content=special_content) + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + expected_output = """Enhanced prompt with special chars: @#$%^&*() +Unicode: 🚀 ✨ 💡 +Quotes: "double" and 'single' +Backslashes: \\n \\t \\r""" + assert result == {"output": expected_output} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_very_long_response( + self, mock_get_llm, mock_apply_template, mock_llm, mock_messages + ): + """Test handling of very long LLM responses.""" + mock_get_llm.return_value = mock_llm + mock_apply_template.return_value = mock_messages + + # Create a very long response + long_content = "This is a very long enhanced prompt. " * 100 + xml_response = f"\n{long_content}\n" + + mock_llm.invoke.return_value = MagicMock(content=xml_response) + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + assert result == {"output": long_content.strip()} + assert len(result["output"]) > 1000 # Verify it's actually long + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_empty_response_content( + self, mock_get_llm, mock_apply_template, mock_llm, mock_messages + ): + """Test handling of empty response content.""" + mock_get_llm.return_value = mock_llm + mock_apply_template.return_value = mock_messages + + mock_llm.invoke.return_value = MagicMock(content="") + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + assert result == {"output": ""} + + @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template") + @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type") + @patch( + "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP", + {"prompt_enhancer": "basic"}, + ) + def test_only_whitespace_response( + self, mock_get_llm, mock_apply_template, mock_llm, mock_messages + ): + """Test handling of response with only whitespace.""" + mock_get_llm.return_value = mock_llm + mock_apply_template.return_value = mock_messages + + mock_llm.invoke.return_value = MagicMock(content=" \n\n\t\t ") + + state = PromptEnhancerState(prompt="Test prompt") + result = prompt_enhancer_node(state) + + assert result == {"output": ""}