From 2ab28765803d4d9582aaa8f2f3355137d154e273 Mon Sep 17 00:00:00 2001 From: Xun Date: Thu, 12 Mar 2026 17:13:39 +0800 Subject: [PATCH] fix: Plan model_validate throw exception in auto_accepted_plan (#1111) * fix: Plan.model_validate throw exception in auto_accepted_plan * improve log * add UT * fix ci * reverse uv.lock * add blank * fix --- .vscode/settings.json | 1 + src/graph/nodes.py | 57 ++++++++++++++++--- tests/integration/test_nodes.py | 97 ++++++++++++++++++++++++++++++++- 3 files changed, 143 insertions(+), 12 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 9b38853..1b314fd 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,5 @@ { + "python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python", "python.testing.pytestArgs": [ "tests" ], diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 9f207c0..28e0829 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -7,6 +7,7 @@ import os import re from functools import partial from typing import Annotated, Any, Literal +from pydantic import ValidationError from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableConfig @@ -496,8 +497,9 @@ def human_feedback_node( ) # if the plan is accepted, run the following node - plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 + plan_iterations = (state.get("plan_iterations") or 0) + 1 goto = "research_team" + configurable = Configuration.from_runnable_config(config) try: # Safely extract plan content from different types (string, AIMessage, dict) original_plan = current_plan @@ -508,18 +510,55 @@ def human_feedback_node( current_plan = json.loads(current_plan) current_plan_content = extract_plan_content(current_plan) - # increment the plan iterations - plan_iterations += 1 # parse the plan new_plan = json.loads(repair_json_output(current_plan_content)) + + # Some models may return only a raw steps list instead of a full plan object. + # Normalize to Plan schema to avoid ValidationError in Plan.model_validate(). + if isinstance(new_plan, list): + logger.warning("Planner returned plan as list; normalizing to dict with inferred metadata") + new_plan = { + "locale": state.get("locale", "en-US"), + "has_enough_context": False, + "thought": "", + "title": state.get("research_topic") or "Research Plan", + "steps": new_plan, + } + elif not isinstance(new_plan, dict): + raise ValueError(f"Unsupported plan type after parsing: {type(new_plan).__name__}") + + # Fill required fields if partially missing. + new_plan.setdefault("locale", state.get("locale", "en-US")) + new_plan.setdefault("has_enough_context", False) + new_plan.setdefault("thought", "") + if not new_plan.get("title"): + new_plan["title"] = state.get("research_topic") or "Research Plan" + if "steps" not in new_plan or new_plan.get("steps") is None: + new_plan["steps"] = [] + # Validate and fix plan to ensure web search requirements are met - configurable = Configuration.from_runnable_config(config) - new_plan = validate_and_fix_plan(new_plan, configurable.enforce_web_search, configurable.enable_web_search) - except (json.JSONDecodeError, AttributeError, ValueError) as e: + # after normalization so list-shaped plans are also enforced. + new_plan = validate_and_fix_plan( + new_plan, + configurable.enforce_web_search, + configurable.enable_web_search, + ) + + validated_plan = Plan.model_validate(new_plan) + + except (json.JSONDecodeError, AttributeError, ValueError, ValidationError) as e: logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}") if isinstance(current_plan, dict) and "content" in original_plan: logger.warning(f"Plan appears to be an AIMessage object with content field") - if plan_iterations > 1: # the plan_iterations is increased before this check + if plan_iterations < configurable.max_plan_iterations: + return Command( + update={ + "plan_iterations": plan_iterations, + **preserve_state_meta_fields(state), + }, + goto="planner" + ) + if plan_iterations > 1: return Command( update=preserve_state_meta_fields(state), goto="reporter" @@ -532,7 +571,7 @@ def human_feedback_node( # Build update dict with safe locale handling update_dict = { - "current_plan": Plan.model_validate(new_plan), + "current_plan": validated_plan, "plan_iterations": plan_iterations, **preserve_state_meta_fields(state), } @@ -907,7 +946,7 @@ def reporter_node(state: State, config: RunnableConfig): response_content = re.sub( r"[\s\S]*?", "", response_content ).strip() - logger.info(f"reporter response: {response_content}") + logger.debug(f"reporter response length: {len(response_content)}") return { "final_report": response_content, diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index ce7b9c6..2c43745 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -2,6 +2,7 @@ import json from collections import namedtuple from unittest.mock import MagicMock, patch +from pydantic import ValidationError import pytest from src.graph.nodes import ( @@ -825,12 +826,102 @@ def test_human_feedback_node_json_decode_error_first_iteration( state = dict(mock_state_base) state["auto_accepted_plan"] = True state["plan_iterations"] = 0 - with patch( - "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) + mock_configurable = MagicMock() + mock_configurable.max_plan_iterations = 3 + with ( + patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable, + ), + patch( + "src.graph.nodes.json.loads", + side_effect=json.JSONDecodeError("err", "doc", 0), + ), ): result = human_feedback_node(state, mock_config) assert isinstance(result, Command) - assert result.goto == "__end__" + assert result.goto == "planner" + assert result.update["plan_iterations"] == 1 + +def test_human_feedback_node_model_validate_error(mock_state_base, mock_config): + # Plan.model_validate raises ValidationError, should enter error handling path + from pydantic import BaseModel + + state = dict(mock_state_base) + state["auto_accepted_plan"] = True + state["plan_iterations"] = 0 + + # Build a real ValidationError instance from pydantic + class DummyModel(BaseModel): + value: int + + try: + DummyModel.model_validate({"value": "not_an_int"}) + except ValidationError as validation_error: + raised_validation_error = validation_error + + mock_configurable = MagicMock() + mock_configurable.max_plan_iterations = 3 + mock_configurable.enforce_web_search = False + mock_configurable.enable_web_search = True + + with ( + patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable, + ), + patch( + "src.graph.nodes.Plan.model_validate", + side_effect=raised_validation_error, + ), + ): + result = human_feedback_node(state, mock_config) + assert isinstance(result, Command) + assert result.goto == "planner" + assert result.update["plan_iterations"] == 1 + +def test_human_feedback_node_list_plan_runs_enforcement_after_normalization( + mock_state_base, mock_config +): + # Regression: when plan content is a list, normalization happens first, + # then validate_and_fix_plan must still run on the normalized dict. + raw_list_plan = [ + { + "need_search": False, + "title": "Only Step", + "description": "Collect baseline info", + # intentionally missing step_type + } + ] + + state = dict(mock_state_base) + state["auto_accepted_plan"] = True + state["plan_iterations"] = 0 + state["current_plan"] = json.dumps({"content": [json.dumps(raw_list_plan)]}) + + mock_configurable = MagicMock() + mock_configurable.max_plan_iterations = 3 + mock_configurable.enforce_web_search = True + mock_configurable.enable_web_search = True + + with patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable, + ): + result = human_feedback_node(state, mock_config) + + assert isinstance(result, Command) + assert result.goto == "research_team" + assert result.update["plan_iterations"] == 1 + + normalized_plan = result.update["current_plan"] + assert isinstance(normalized_plan, dict) + assert isinstance(normalized_plan.get("steps"), list) + assert len(normalized_plan["steps"]) == 1 + + # validate_and_fix_plan effects should be visible after normalization + assert normalized_plan["steps"][0]["step_type"] == "research" + assert normalized_plan["steps"][0]["need_search"] is True def test_human_feedback_node_json_decode_error_second_iteration(