diff --git a/src/graph/nodes.py b/src/graph/nodes.py index f85ab45..56ee8df 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -5,7 +5,7 @@ import json import logging import os from functools import partial -from typing import Annotated, Literal +from typing import Any, Annotated, Literal from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.runnables import RunnableConfig @@ -361,6 +361,34 @@ def planner_node( ) +def extract_plan_content(plan_data: str | dict | Any) -> str: + """ + Safely extract plan content from different types of plan data. + + Args: + plan_data: The plan data which can be a string, AIMessage, or dict + + Returns: + str: The plan content as a string (JSON string for dict inputs, or + extracted/original string for other types) + """ + if isinstance(plan_data, str): + # If it's already a string, return as is + return plan_data + elif hasattr(plan_data, 'content') and isinstance(plan_data.content, str): + # If it's an AIMessage or similar object with a content attribute + logger.debug(f"Extracting plan content from message object of type {type(plan_data).__name__}") + return plan_data.content + elif isinstance(plan_data, dict): + # If it's already a dictionary, convert to JSON string + logger.debug("Converting plan dictionary to JSON string") + return json.dumps(plan_data) + else: + # For any other type, try to convert to string + logger.warning(f"Unexpected plan data type {type(plan_data).__name__}, attempting to convert to string") + return str(plan_data) + + def human_feedback_node( state: State, config: RunnableConfig ) -> Command[Literal["planner", "research_team", "reporter", "__end__"]]: @@ -406,7 +434,13 @@ def human_feedback_node( plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 goto = "research_team" try: - current_plan = repair_json_output(current_plan) + # Safely extract plan content from different types (string, AIMessage, dict) + original_plan = current_plan + current_plan_content = extract_plan_content(current_plan) + logger.debug(f"Extracted plan content type: {type(current_plan_content).__name__}") + + # Repair the JSON output + current_plan = repair_json_output(current_plan_content) # increment the plan iterations plan_iterations += 1 # parse the plan @@ -414,8 +448,10 @@ def human_feedback_node( # Validate and fix plan to ensure web search requirements are met configurable = Configuration.from_runnable_config(config) new_plan = validate_and_fix_plan(new_plan, configurable.enforce_web_search) - except json.JSONDecodeError: - logger.warning("Planner response is not a valid JSON") + except (json.JSONDecodeError, AttributeError) as e: + logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}") + if isinstance(current_plan, dict) and "content" in original_plan: + logger.warning(f"Plan appears to be an AIMessage object with content field") if plan_iterations > 1: # the plan_iterations is increased before this check return Command( update=preserve_state_meta_fields(state), diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index d788dc6..1a7eebf 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -12,8 +12,155 @@ from src.graph.nodes import ( planner_node, reporter_node, researcher_node, + extract_plan_content, ) + +class TestExtractPlanContent: + """Test cases for the extract_plan_content function.""" + + def test_extract_plan_content_with_string(self): + """Test that extract_plan_content returns the input string as-is.""" + plan_json_str = '{"locale": "en-US", "has_enough_context": false, "title": "Test Plan"}' + result = extract_plan_content(plan_json_str) + assert result == plan_json_str + + def test_extract_plan_content_with_ai_message(self): + """Test that extract_plan_content extracts content from an AIMessage-like object.""" + # Create a mock AIMessage object + class MockAIMessage: + def __init__(self, content): + self.content = content + + plan_content = '{"locale": "zh-CN", "has_enough_context": false, "title": "测试计划"}' + plan_message = MockAIMessage(plan_content) + + result = extract_plan_content(plan_message) + assert result == plan_content + + def test_extract_plan_content_with_dict(self): + """Test that extract_plan_content converts a dictionary to JSON string.""" + plan_dict = { + "locale": "fr-FR", + "has_enough_context": True, + "title": "Plan de test", + "steps": [] + } + expected_json = json.dumps(plan_dict) + + result = extract_plan_content(plan_dict) + assert result == expected_json + + def test_extract_plan_content_with_other_type(self): + """Test that extract_plan_content converts other types to string.""" + plan_value = 12345 + expected_string = "12345" + + result = extract_plan_content(plan_value) + assert result == expected_string + + def test_extract_plan_content_with_complex_dict(self): + """Test that extract_plan_content handles complex nested dictionaries.""" + plan_dict = { + "locale": "zh-CN", + "has_enough_context": False, + "title": "埃菲尔铁塔与世界最高建筑高度比较研究计划", + "thought": "要回答埃菲尔铁塔比世界最高建筑高多少倍的问题,我们需要知道埃菲尔铁塔的高度以及当前世界最高建筑的高度。", + "steps": [ + { + "need_search": True, + "title": "收集埃菲尔铁塔和世界最高建筑的高度数据", + "description": "从可靠来源检索埃菲尔铁塔的确切高度以及目前被公认为世界最高建筑的建筑物及其高度数据。", + "step_type": "research" + }, + { + "need_search": True, + "title": "查找其他超高建筑作为对比基准", + "description": "获取其他具有代表性的超高建筑的高度数据,以提供更全面的比较背景。", + "step_type": "research" + } + ] + } + + result = extract_plan_content(plan_dict) + # Verify the result can be parsed back to a dictionary + parsed_result = json.loads(result) + assert parsed_result == plan_dict + + def test_extract_plan_content_with_non_string_content(self): + """Test that extract_plan_content handles AIMessage with non-string content.""" + class MockAIMessageWithNonStringContent: + def __init__(self, content): + self.content = content + + # Test with non-string content (should not be extracted) + plan_content = 12345 + plan_message = MockAIMessageWithNonStringContent(plan_content) + + result = extract_plan_content(plan_message) + # Should convert the entire object to string since content is not a string + assert isinstance(result, str) + assert "MockAIMessageWithNonStringContent" in result + + def test_extract_plan_content_with_empty_string(self): + """Test that extract_plan_content handles empty strings.""" + empty_string = "" + result = extract_plan_content(empty_string) + assert result == "" + + def test_extract_plan_content_with_empty_dict(self): + """Test that extract_plan_content handles empty dictionaries.""" + empty_dict = {} + expected_json = "{}" + + result = extract_plan_content(empty_dict) + assert result == expected_json + + def test_extract_plan_content_issue_703_case(self): + """Test that extract_plan_content handles the specific case from issue #703.""" + # This is the exact structure that was causing the error in issue #703 + class MockAIMessageFromIssue703: + def __init__(self, content): + self.content = content + self.additional_kwargs = {} + self.response_metadata = {'finish_reason': 'stop', 'model_name': 'qwen-max-latest'} + self.type = 'ai' + self.id = 'run--ebc626af-3845-472b-aeee-acddebf5a4ea' + self.example = False + self.tool_calls = [] + self.invalid_tool_calls = [] + + plan_content = '''{ + "locale": "zh-CN", + "has_enough_context": false, + "thought": "要回答埃菲尔铁塔比世界最高建筑高多少倍的问题,我们需要知道埃菲尔铁塔的高度以及当前世界最高建筑的高度。", + "title": "埃菲尔铁塔与世界最高建筑高度比较研究计划", + "steps": [ + { + "need_search": true, + "title": "收集埃菲尔铁塔和世界最高建筑的高度数据", + "description": "从可靠来源检索埃菲尔铁塔的确切高度以及目前被公认为世界最高建筑的建筑物及其高度数据。", + "step_type": "research" + } + ] + }''' + + plan_message = MockAIMessageFromIssue703(plan_content) + + # Extract the content + result = extract_plan_content(plan_message) + + # Verify the extracted content is the same as the original + assert result == plan_content + + # Verify the extracted content can be parsed as JSON + parsed_result = json.loads(result) + assert parsed_result["locale"] == "zh-CN" + assert parsed_result["title"] == "埃菲尔铁塔与世界最高建筑高度比较研究计划" + assert len(parsed_result["steps"]) == 1 + assert parsed_result["steps"][0]["title"] == "收集埃菲尔铁塔和世界最高建筑的高度数据" + + # 在这里 mock 掉 get_llm_by_type,避免 ValueError with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()): from langchain_core.messages import HumanMessage