fix: Plan.model_validate throws exception in auto_accepted_plan (#1111)

* fix: Plan.model_validate throws exception in auto_accepted_plan

* improve log

* add UT

* fix ci

* revert uv.lock

* add blank

* fix
This commit is contained in:
Xun
2026-03-12 17:13:39 +08:00
committed by GitHub
parent 172ba2d7ad
commit 2ab2876580
3 changed files with 143 additions and 12 deletions

View File

@@ -1,4 +1,5 @@
{ {
"python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python",
"python.testing.pytestArgs": [ "python.testing.pytestArgs": [
"tests" "tests"
], ],

View File

@@ -7,6 +7,7 @@ import os
import re import re
from functools import partial from functools import partial
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from pydantic import ValidationError
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
@@ -496,8 +497,9 @@ def human_feedback_node(
) )
# if the plan is accepted, run the following node # if the plan is accepted, run the following node
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 plan_iterations = (state.get("plan_iterations") or 0) + 1
goto = "research_team" goto = "research_team"
configurable = Configuration.from_runnable_config(config)
try: try:
# Safely extract plan content from different types (string, AIMessage, dict) # Safely extract plan content from different types (string, AIMessage, dict)
original_plan = current_plan original_plan = current_plan
@@ -508,18 +510,55 @@ def human_feedback_node(
current_plan = json.loads(current_plan) current_plan = json.loads(current_plan)
current_plan_content = extract_plan_content(current_plan) current_plan_content = extract_plan_content(current_plan)
# increment the plan iterations
plan_iterations += 1
# parse the plan # parse the plan
new_plan = json.loads(repair_json_output(current_plan_content)) new_plan = json.loads(repair_json_output(current_plan_content))
# Some models may return only a raw steps list instead of a full plan object.
# Normalize to Plan schema to avoid ValidationError in Plan.model_validate().
if isinstance(new_plan, list):
logger.warning("Planner returned plan as list; normalizing to dict with inferred metadata")
new_plan = {
"locale": state.get("locale", "en-US"),
"has_enough_context": False,
"thought": "",
"title": state.get("research_topic") or "Research Plan",
"steps": new_plan,
}
elif not isinstance(new_plan, dict):
raise ValueError(f"Unsupported plan type after parsing: {type(new_plan).__name__}")
# Fill required fields if partially missing.
new_plan.setdefault("locale", state.get("locale", "en-US"))
new_plan.setdefault("has_enough_context", False)
new_plan.setdefault("thought", "")
if not new_plan.get("title"):
new_plan["title"] = state.get("research_topic") or "Research Plan"
if "steps" not in new_plan or new_plan.get("steps") is None:
new_plan["steps"] = []
# Validate and fix plan to ensure web search requirements are met # Validate and fix plan to ensure web search requirements are met
configurable = Configuration.from_runnable_config(config) # after normalization so list-shaped plans are also enforced.
new_plan = validate_and_fix_plan(new_plan, configurable.enforce_web_search, configurable.enable_web_search) new_plan = validate_and_fix_plan(
except (json.JSONDecodeError, AttributeError, ValueError) as e: new_plan,
configurable.enforce_web_search,
configurable.enable_web_search,
)
validated_plan = Plan.model_validate(new_plan)
except (json.JSONDecodeError, AttributeError, ValueError, ValidationError) as e:
logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}") logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}")
if isinstance(current_plan, dict) and "content" in original_plan: if isinstance(current_plan, dict) and "content" in original_plan:
logger.warning(f"Plan appears to be an AIMessage object with content field") logger.warning(f"Plan appears to be an AIMessage object with content field")
if plan_iterations > 1: # the plan_iterations is increased before this check if plan_iterations < configurable.max_plan_iterations:
return Command(
update={
"plan_iterations": plan_iterations,
**preserve_state_meta_fields(state),
},
goto="planner"
)
if plan_iterations > 1:
return Command( return Command(
update=preserve_state_meta_fields(state), update=preserve_state_meta_fields(state),
goto="reporter" goto="reporter"
@@ -532,7 +571,7 @@ def human_feedback_node(
# Build update dict with safe locale handling # Build update dict with safe locale handling
update_dict = { update_dict = {
"current_plan": Plan.model_validate(new_plan), "current_plan": validated_plan,
"plan_iterations": plan_iterations, "plan_iterations": plan_iterations,
**preserve_state_meta_fields(state), **preserve_state_meta_fields(state),
} }
@@ -907,7 +946,7 @@ def reporter_node(state: State, config: RunnableConfig):
response_content = re.sub( response_content = re.sub(
r"<think>[\s\S]*?</think>", "", response_content r"<think>[\s\S]*?</think>", "", response_content
).strip() ).strip()
logger.info(f"reporter response: {response_content}") logger.debug(f"reporter response length: {len(response_content)}")
return { return {
"final_report": response_content, "final_report": response_content,

View File

@@ -2,6 +2,7 @@ import json
from collections import namedtuple from collections import namedtuple
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from pydantic import ValidationError
import pytest import pytest
from src.graph.nodes import ( from src.graph.nodes import (
@@ -825,12 +826,102 @@ def test_human_feedback_node_json_decode_error_first_iteration(
state = dict(mock_state_base) state = dict(mock_state_base)
state["auto_accepted_plan"] = True state["auto_accepted_plan"] = True
state["plan_iterations"] = 0 state["plan_iterations"] = 0
with patch( mock_configurable = MagicMock()
"src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) mock_configurable.max_plan_iterations = 3
with (
patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
),
patch(
"src.graph.nodes.json.loads",
side_effect=json.JSONDecodeError("err", "doc", 0),
),
): ):
result = human_feedback_node(state, mock_config) result = human_feedback_node(state, mock_config)
assert isinstance(result, Command) assert isinstance(result, Command)
assert result.goto == "__end__" assert result.goto == "planner"
assert result.update["plan_iterations"] == 1
def test_human_feedback_node_model_validate_error(mock_state_base, mock_config):
# Plan.model_validate raises ValidationError, should enter error handling path
from pydantic import BaseModel
state = dict(mock_state_base)
state["auto_accepted_plan"] = True
state["plan_iterations"] = 0
# Build a real ValidationError instance from pydantic
class DummyModel(BaseModel):
value: int
try:
DummyModel.model_validate({"value": "not_an_int"})
except ValidationError as validation_error:
raised_validation_error = validation_error
mock_configurable = MagicMock()
mock_configurable.max_plan_iterations = 3
mock_configurable.enforce_web_search = False
mock_configurable.enable_web_search = True
with (
patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
),
patch(
"src.graph.nodes.Plan.model_validate",
side_effect=raised_validation_error,
),
):
result = human_feedback_node(state, mock_config)
assert isinstance(result, Command)
assert result.goto == "planner"
assert result.update["plan_iterations"] == 1
def test_human_feedback_node_list_plan_runs_enforcement_after_normalization(
mock_state_base, mock_config
):
# Regression: when plan content is a list, normalization happens first,
# then validate_and_fix_plan must still run on the normalized dict.
raw_list_plan = [
{
"need_search": False,
"title": "Only Step",
"description": "Collect baseline info",
# intentionally missing step_type
}
]
state = dict(mock_state_base)
state["auto_accepted_plan"] = True
state["plan_iterations"] = 0
state["current_plan"] = json.dumps({"content": [json.dumps(raw_list_plan)]})
mock_configurable = MagicMock()
mock_configurable.max_plan_iterations = 3
mock_configurable.enforce_web_search = True
mock_configurable.enable_web_search = True
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
):
result = human_feedback_node(state, mock_config)
assert isinstance(result, Command)
assert result.goto == "research_team"
assert result.update["plan_iterations"] == 1
normalized_plan = result.update["current_plan"]
assert isinstance(normalized_plan, dict)
assert isinstance(normalized_plan.get("steps"), list)
assert len(normalized_plan["steps"]) == 1
# validate_and_fix_plan effects should be visible after normalization
assert normalized_plan["steps"][0]["step_type"] == "research"
assert normalized_plan["steps"][0]["need_search"] is True
def test_human_feedback_node_json_decode_error_second_iteration( def test_human_feedback_node_json_decode_error_second_iteration(