fix: Plan.model_validate throws an exception in auto_accepted_plan (#1111)

* fix: Plan.model_validate throws an exception in auto_accepted_plan

* improve log

* add UT

* fix ci

* revert uv.lock

* add blank

* fix
This commit is contained in:
Xun
2026-03-12 17:13:39 +08:00
committed by GitHub
parent 172ba2d7ad
commit 2ab2876580
3 changed files with 143 additions and 12 deletions

View File

@@ -1,4 +1,5 @@
{
"python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python",
"python.testing.pytestArgs": [
"tests"
],

View File

@@ -7,6 +7,7 @@ import os
import re
from functools import partial
from typing import Annotated, Any, Literal
from pydantic import ValidationError
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
@@ -496,8 +497,9 @@ def human_feedback_node(
)
# if the plan is accepted, run the following node
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
plan_iterations = (state.get("plan_iterations") or 0) + 1
goto = "research_team"
configurable = Configuration.from_runnable_config(config)
try:
# Safely extract plan content from different types (string, AIMessage, dict)
original_plan = current_plan
@@ -508,18 +510,55 @@ def human_feedback_node(
current_plan = json.loads(current_plan)
current_plan_content = extract_plan_content(current_plan)
# increment the plan iterations
plan_iterations += 1
# parse the plan
new_plan = json.loads(repair_json_output(current_plan_content))
# Some models may return only a raw steps list instead of a full plan object.
# Normalize to Plan schema to avoid ValidationError in Plan.model_validate().
if isinstance(new_plan, list):
logger.warning("Planner returned plan as list; normalizing to dict with inferred metadata")
new_plan = {
"locale": state.get("locale", "en-US"),
"has_enough_context": False,
"thought": "",
"title": state.get("research_topic") or "Research Plan",
"steps": new_plan,
}
elif not isinstance(new_plan, dict):
raise ValueError(f"Unsupported plan type after parsing: {type(new_plan).__name__}")
# Fill required fields if partially missing.
new_plan.setdefault("locale", state.get("locale", "en-US"))
new_plan.setdefault("has_enough_context", False)
new_plan.setdefault("thought", "")
if not new_plan.get("title"):
new_plan["title"] = state.get("research_topic") or "Research Plan"
if "steps" not in new_plan or new_plan.get("steps") is None:
new_plan["steps"] = []
# Validate and fix plan to ensure web search requirements are met
configurable = Configuration.from_runnable_config(config)
new_plan = validate_and_fix_plan(new_plan, configurable.enforce_web_search, configurable.enable_web_search)
except (json.JSONDecodeError, AttributeError, ValueError) as e:
# after normalization so list-shaped plans are also enforced.
new_plan = validate_and_fix_plan(
new_plan,
configurable.enforce_web_search,
configurable.enable_web_search,
)
validated_plan = Plan.model_validate(new_plan)
except (json.JSONDecodeError, AttributeError, ValueError, ValidationError) as e:
logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}")
if isinstance(current_plan, dict) and "content" in original_plan:
logger.warning(f"Plan appears to be an AIMessage object with content field")
if plan_iterations > 1: # the plan_iterations is increased before this check
if plan_iterations < configurable.max_plan_iterations:
return Command(
update={
"plan_iterations": plan_iterations,
**preserve_state_meta_fields(state),
},
goto="planner"
)
if plan_iterations > 1:
return Command(
update=preserve_state_meta_fields(state),
goto="reporter"
@@ -532,7 +571,7 @@ def human_feedback_node(
# Build update dict with safe locale handling
update_dict = {
"current_plan": Plan.model_validate(new_plan),
"current_plan": validated_plan,
"plan_iterations": plan_iterations,
**preserve_state_meta_fields(state),
}
@@ -907,7 +946,7 @@ def reporter_node(state: State, config: RunnableConfig):
response_content = re.sub(
r"<think>[\s\S]*?</think>", "", response_content
).strip()
logger.info(f"reporter response: {response_content}")
logger.debug(f"reporter response length: {len(response_content)}")
return {
"final_report": response_content,

View File

@@ -2,6 +2,7 @@ import json
from collections import namedtuple
from unittest.mock import MagicMock, patch
from pydantic import ValidationError
import pytest
from src.graph.nodes import (
@@ -825,12 +826,102 @@ def test_human_feedback_node_json_decode_error_first_iteration(
state = dict(mock_state_base)
state["auto_accepted_plan"] = True
state["plan_iterations"] = 0
with patch(
"src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0)
mock_configurable = MagicMock()
mock_configurable.max_plan_iterations = 3
with (
patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
),
patch(
"src.graph.nodes.json.loads",
side_effect=json.JSONDecodeError("err", "doc", 0),
),
):
result = human_feedback_node(state, mock_config)
assert isinstance(result, Command)
assert result.goto == "__end__"
assert result.goto == "planner"
assert result.update["plan_iterations"] == 1
def test_human_feedback_node_model_validate_error(mock_state_base, mock_config):
# Plan.model_validate raises ValidationError, should enter error handling path
from pydantic import BaseModel
state = dict(mock_state_base)
state["auto_accepted_plan"] = True
state["plan_iterations"] = 0
# Build a real ValidationError instance from pydantic
class DummyModel(BaseModel):
value: int
try:
DummyModel.model_validate({"value": "not_an_int"})
except ValidationError as validation_error:
raised_validation_error = validation_error
mock_configurable = MagicMock()
mock_configurable.max_plan_iterations = 3
mock_configurable.enforce_web_search = False
mock_configurable.enable_web_search = True
with (
patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
),
patch(
"src.graph.nodes.Plan.model_validate",
side_effect=raised_validation_error,
),
):
result = human_feedback_node(state, mock_config)
assert isinstance(result, Command)
assert result.goto == "planner"
assert result.update["plan_iterations"] == 1
def test_human_feedback_node_list_plan_runs_enforcement_after_normalization(
mock_state_base, mock_config
):
# Regression: when plan content is a list, normalization happens first,
# then validate_and_fix_plan must still run on the normalized dict.
raw_list_plan = [
{
"need_search": False,
"title": "Only Step",
"description": "Collect baseline info",
# intentionally missing step_type
}
]
state = dict(mock_state_base)
state["auto_accepted_plan"] = True
state["plan_iterations"] = 0
state["current_plan"] = json.dumps({"content": [json.dumps(raw_list_plan)]})
mock_configurable = MagicMock()
mock_configurable.max_plan_iterations = 3
mock_configurable.enforce_web_search = True
mock_configurable.enable_web_search = True
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
):
result = human_feedback_node(state, mock_config)
assert isinstance(result, Command)
assert result.goto == "research_team"
assert result.update["plan_iterations"] == 1
normalized_plan = result.update["current_plan"]
assert isinstance(normalized_plan, dict)
assert isinstance(normalized_plan.get("steps"), list)
assert len(normalized_plan["steps"]) == 1
# validate_and_fix_plan effects should be visible after normalization
assert normalized_plan["steps"][0]["step_type"] == "research"
assert normalized_plan["steps"][0]["need_search"] is True
def test_human_feedback_node_json_decode_error_second_iteration(