mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
fix: presever the local setting between frontend and backend (#670)
* fix: presever the local setting between frontend and backend * Added unit test for the state preservation * fix: passing the locale to the agent call * fix: apply the fix after code review
This commit is contained in:
@@ -84,6 +84,31 @@ def needs_clarification(state: dict) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def preserve_state_meta_fields(state: State) -> dict:
|
||||
"""
|
||||
Extract meta/config fields that should be preserved across state transitions.
|
||||
|
||||
These fields are critical for workflow continuity and should be explicitly
|
||||
included in all Command.update dicts to prevent them from reverting to defaults.
|
||||
|
||||
Args:
|
||||
state: Current state object
|
||||
|
||||
Returns:
|
||||
Dict of meta fields to preserve
|
||||
"""
|
||||
return {
|
||||
"locale": state.get("locale", "en-US"),
|
||||
"research_topic": state.get("research_topic", ""),
|
||||
"clarified_research_topic": state.get("clarified_research_topic", ""),
|
||||
"clarification_history": state.get("clarification_history", []),
|
||||
"enable_clarification": state.get("enable_clarification", False),
|
||||
"max_clarification_rounds": state.get("max_clarification_rounds", 3),
|
||||
"clarification_rounds": state.get("clarification_rounds", 0),
|
||||
"resources": state.get("resources", []),
|
||||
}
|
||||
|
||||
|
||||
def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
|
||||
"""
|
||||
Validate and fix a plan to ensure it meets requirements.
|
||||
@@ -217,7 +242,7 @@ def planner_node(
|
||||
state: State, config: RunnableConfig
|
||||
) -> Command[Literal["human_feedback", "reporter"]]:
|
||||
"""Planner node that generate the full plan."""
|
||||
logger.info("Planner generating full plan")
|
||||
logger.info("Planner generating full plan with locale: %s", state.get("locale", "en-US"))
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
|
||||
@@ -266,7 +291,10 @@ def planner_node(
|
||||
|
||||
# if the plan iterations is greater than the max plan iterations, return the reporter node
|
||||
if plan_iterations >= configurable.max_plan_iterations:
|
||||
return Command(goto="reporter")
|
||||
return Command(
|
||||
update=preserve_state_meta_fields(state),
|
||||
goto="reporter"
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking:
|
||||
@@ -284,9 +312,15 @@ def planner_node(
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Planner response is not a valid JSON")
|
||||
if plan_iterations > 0:
|
||||
return Command(goto="reporter")
|
||||
return Command(
|
||||
update=preserve_state_meta_fields(state),
|
||||
goto="reporter"
|
||||
)
|
||||
else:
|
||||
return Command(goto="__end__")
|
||||
return Command(
|
||||
update=preserve_state_meta_fields(state),
|
||||
goto="__end__"
|
||||
)
|
||||
|
||||
# Validate and fix plan to ensure web search requirements are met
|
||||
if isinstance(curr_plan, dict):
|
||||
@@ -299,6 +333,7 @@ def planner_node(
|
||||
update={
|
||||
"messages": [AIMessage(content=full_response, name="planner")],
|
||||
"current_plan": new_plan,
|
||||
**preserve_state_meta_fields(state),
|
||||
},
|
||||
goto="reporter",
|
||||
)
|
||||
@@ -306,6 +341,7 @@ def planner_node(
|
||||
update={
|
||||
"messages": [AIMessage(content=full_response, name="planner")],
|
||||
"current_plan": full_response,
|
||||
**preserve_state_meta_fields(state),
|
||||
},
|
||||
goto="human_feedback",
|
||||
)
|
||||
@@ -323,7 +359,10 @@ def human_feedback_node(
|
||||
# Handle None or empty feedback
|
||||
if not feedback:
|
||||
logger.warning(f"Received empty or None feedback: {feedback}. Returning to planner for new plan.")
|
||||
return Command(goto="planner")
|
||||
return Command(
|
||||
update=preserve_state_meta_fields(state),
|
||||
goto="planner"
|
||||
)
|
||||
|
||||
# Normalize feedback string
|
||||
feedback_normalized = str(feedback).strip().upper()
|
||||
@@ -336,6 +375,7 @@ def human_feedback_node(
|
||||
"messages": [
|
||||
HumanMessage(content=feedback, name="feedback"),
|
||||
],
|
||||
**preserve_state_meta_fields(state),
|
||||
},
|
||||
goto="planner",
|
||||
)
|
||||
@@ -343,7 +383,10 @@ def human_feedback_node(
|
||||
logger.info("Plan is accepted by user.")
|
||||
else:
|
||||
logger.warning(f"Unsupported feedback format: {feedback}. Please use '[ACCEPTED]' to accept or '[EDIT_PLAN]' to edit.")
|
||||
return Command(goto="planner")
|
||||
return Command(
|
||||
update=preserve_state_meta_fields(state),
|
||||
goto="planner"
|
||||
)
|
||||
|
||||
# if the plan is accepted, run the following node
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
@@ -360,16 +403,29 @@ def human_feedback_node(
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Planner response is not a valid JSON")
|
||||
if plan_iterations > 1: # the plan_iterations is increased before this check
|
||||
return Command(goto="reporter")
|
||||
return Command(
|
||||
update=preserve_state_meta_fields(state),
|
||||
goto="reporter"
|
||||
)
|
||||
else:
|
||||
return Command(goto="__end__")
|
||||
return Command(
|
||||
update=preserve_state_meta_fields(state),
|
||||
goto="__end__"
|
||||
)
|
||||
|
||||
# Build update dict with safe locale handling
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
"plan_iterations": plan_iterations,
|
||||
**preserve_state_meta_fields(state),
|
||||
}
|
||||
|
||||
# Only override locale if new_plan provides a valid value, otherwise use preserved locale
|
||||
if new_plan.get("locale"):
|
||||
update_dict["locale"] = new_plan["locale"]
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
"plan_iterations": plan_iterations,
|
||||
"locale": new_plan["locale"],
|
||||
},
|
||||
update=update_dict,
|
||||
goto=goto,
|
||||
)
|
||||
|
||||
@@ -408,6 +464,7 @@ def coordinator_node(
|
||||
|
||||
goto = "__end__"
|
||||
locale = state.get("locale", "en-US")
|
||||
logger.info(f"Coordinator locale: {locale}")
|
||||
research_topic = state.get("research_topic", "")
|
||||
|
||||
# Process tool calls for legacy mode
|
||||
@@ -421,9 +478,8 @@ def coordinator_node(
|
||||
logger.info("Handing off to planner")
|
||||
goto = "planner"
|
||||
|
||||
# Extract locale and research_topic if provided
|
||||
if tool_args.get("locale") and tool_args.get("research_topic"):
|
||||
locale = tool_args.get("locale")
|
||||
# Extract research_topic if provided
|
||||
if tool_args.get("research_topic"):
|
||||
research_topic = tool_args.get("research_topic")
|
||||
break
|
||||
|
||||
@@ -587,8 +643,6 @@ def coordinator_node(
|
||||
logger.info("Handing off to planner")
|
||||
goto = "planner"
|
||||
|
||||
# Extract locale if provided
|
||||
locale = tool_args.get("locale", locale)
|
||||
if not enable_clarification and tool_args.get("research_topic"):
|
||||
research_topic = tool_args["research_topic"]
|
||||
|
||||
@@ -725,7 +779,10 @@ async def _execute_agent_step(
|
||||
|
||||
if not current_step:
|
||||
logger.warning(f"[_execute_agent_step] No unexecuted step found in {len(current_plan.steps)} total steps")
|
||||
return Command(goto="research_team")
|
||||
return Command(
|
||||
update=preserve_state_meta_fields(state),
|
||||
goto="research_team"
|
||||
)
|
||||
|
||||
logger.info(f"[_execute_agent_step] Executing step: {current_step.title}, agent: {agent_name}")
|
||||
logger.debug(f"[_execute_agent_step] Completed steps so far: {len(completed_steps)}")
|
||||
@@ -834,6 +891,7 @@ async def _execute_agent_step(
|
||||
)
|
||||
],
|
||||
"observations": observations + [detailed_error],
|
||||
**preserve_state_meta_fields(state),
|
||||
},
|
||||
goto="research_team",
|
||||
)
|
||||
@@ -859,6 +917,7 @@ async def _execute_agent_step(
|
||||
)
|
||||
],
|
||||
"observations": observations + [response_content],
|
||||
**preserve_state_meta_fields(state),
|
||||
},
|
||||
goto="research_team",
|
||||
)
|
||||
|
||||
@@ -106,6 +106,8 @@ async def chat_stream(request: ChatRequest):
|
||||
# Check if MCP server configuration is enabled
|
||||
mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)
|
||||
|
||||
logger.debug(f"get the request locale : {request.locale}")
|
||||
|
||||
# Validate MCP settings if provided
|
||||
if request.mcp_settings and not mcp_enabled:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -675,7 +675,7 @@ def test_coordinator_node_with_tool_calls_locale_override(
|
||||
tool_calls = [
|
||||
{
|
||||
"name": "handoff_to_planner",
|
||||
"args": {"locale": "zh-CN", "research_topic": "test topic"},
|
||||
"args": {"locale": "auto", "research_topic": "test topic"},
|
||||
}
|
||||
]
|
||||
with (
|
||||
@@ -689,7 +689,7 @@ def test_coordinator_node_with_tool_calls_locale_override(
|
||||
|
||||
result = coordinator_node(mock_state_coordinator, MagicMock())
|
||||
assert result.goto == "planner"
|
||||
assert result.update["locale"] == "zh-CN"
|
||||
assert result.update["locale"] == "en-US"
|
||||
assert result.update["research_topic"] == "test topic"
|
||||
assert result.update["resources"] == ["resource1", "resource2"]
|
||||
assert result.update["resources"] == ["resource1", "resource2"]
|
||||
|
||||
241
tests/unit/graph/test_agent_locale_restoration.py
Normal file
241
tests/unit/graph/test_agent_locale_restoration.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for agent locale restoration after create_react_agent execution.
|
||||
|
||||
Tests that meta fields (especially locale) are properly restored after
|
||||
agent.ainvoke() returns, since create_react_agent creates a MessagesState
|
||||
subgraph that filters out custom fields.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.graph.nodes import preserve_state_meta_fields
|
||||
from src.graph.types import State
|
||||
|
||||
|
||||
class TestAgentLocaleRestoration:
|
||||
"""Test suite for locale restoration after agent execution."""
|
||||
|
||||
def test_locale_lost_in_agent_subgraph(self):
|
||||
"""
|
||||
Demonstrate the problem: agent subgraph filters out locale.
|
||||
|
||||
When create_react_agent creates a subgraph with MessagesState,
|
||||
it only returns messages, not custom fields.
|
||||
"""
|
||||
# Simulate agent behavior: only returns messages
|
||||
initial_state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# Agent subgraph returns (like MessagesState would)
|
||||
agent_result = {
|
||||
"messages": ["agent response"],
|
||||
}
|
||||
|
||||
# Problem: locale is missing
|
||||
assert "locale" not in agent_result
|
||||
assert agent_result.get("locale") is None
|
||||
|
||||
def test_locale_restoration_after_agent(self):
|
||||
"""Test that locale can be restored after agent.ainvoke() returns."""
|
||||
initial_state = State(
|
||||
messages=[],
|
||||
locale="zh-CN",
|
||||
research_topic="test",
|
||||
)
|
||||
|
||||
# Simulate agent returning (MessagesState only)
|
||||
agent_result = {
|
||||
"messages": ["agent response"],
|
||||
}
|
||||
|
||||
# Apply restoration
|
||||
preserved = preserve_state_meta_fields(initial_state)
|
||||
agent_result.update(preserved)
|
||||
|
||||
# Verify restoration worked
|
||||
assert agent_result["locale"] == "zh-CN"
|
||||
assert agent_result["research_topic"] == "test"
|
||||
assert "messages" in agent_result
|
||||
|
||||
def test_all_meta_fields_restored(self):
|
||||
"""Test that all meta fields are restored, not just locale."""
|
||||
initial_state = State(
|
||||
messages=[],
|
||||
locale="en-US",
|
||||
research_topic="Original Topic",
|
||||
clarified_research_topic="Clarified Topic",
|
||||
clarification_history=["Q1", "A1"],
|
||||
enable_clarification=True,
|
||||
max_clarification_rounds=5,
|
||||
clarification_rounds=2,
|
||||
resources=["resource1"],
|
||||
)
|
||||
|
||||
# Agent result
|
||||
agent_result = {"messages": ["response"]}
|
||||
agent_result.update(preserve_state_meta_fields(initial_state))
|
||||
|
||||
# All fields should be restored
|
||||
assert agent_result["locale"] == "en-US"
|
||||
assert agent_result["research_topic"] == "Original Topic"
|
||||
assert agent_result["clarified_research_topic"] == "Clarified Topic"
|
||||
assert agent_result["clarification_history"] == ["Q1", "A1"]
|
||||
assert agent_result["enable_clarification"] is True
|
||||
assert agent_result["max_clarification_rounds"] == 5
|
||||
assert agent_result["clarification_rounds"] == 2
|
||||
assert agent_result["resources"] == ["resource1"]
|
||||
|
||||
def test_locale_preservation_through_agent_cycle(self):
|
||||
"""Test the complete cycle: state in → agent → state out."""
|
||||
# Initial state with zh-CN locale
|
||||
initial_state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# Step 1: Extract meta fields
|
||||
preserved = preserve_state_meta_fields(initial_state)
|
||||
assert preserved["locale"] == "zh-CN"
|
||||
|
||||
# Step 2: Agent runs and returns only messages
|
||||
agent_result = {"messages": ["agent output"]}
|
||||
assert "locale" not in agent_result # Missing!
|
||||
|
||||
# Step 3: Restore meta fields
|
||||
agent_result.update(preserved)
|
||||
|
||||
# Step 4: Verify locale is restored
|
||||
assert agent_result["locale"] == "zh-CN"
|
||||
|
||||
# Step 5: Create new state with restored fields
|
||||
final_state = State(messages=agent_result["messages"], **preserved)
|
||||
assert final_state.get("locale") == "zh-CN"
|
||||
|
||||
def test_locale_not_auto_after_restoration(self):
|
||||
"""
|
||||
Test that locale is NOT "auto" after restoration.
|
||||
|
||||
This tests the specific bug: locale was becoming "auto"
|
||||
instead of the preserved "zh-CN" value.
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# Agent returns without locale
|
||||
agent_result = {"messages": []}
|
||||
|
||||
# Before fix: locale would be "auto" (default behavior)
|
||||
# After fix: locale is preserved
|
||||
agent_result.update(preserve_state_meta_fields(state))
|
||||
|
||||
assert agent_result.get("locale") == "zh-CN"
|
||||
assert agent_result.get("locale") != "auto"
|
||||
assert agent_result.get("locale") is not None
|
||||
|
||||
def test_chinese_locale_preserved(self):
|
||||
"""Test that Chinese locale specifically is preserved."""
|
||||
locales_to_test = ["zh-CN", "zh", "zh-Hans", "zh-Hant"]
|
||||
|
||||
for locale_value in locales_to_test:
|
||||
state = State(messages=[], locale=locale_value)
|
||||
agent_result = {"messages": []}
|
||||
|
||||
agent_result.update(preserve_state_meta_fields(state))
|
||||
|
||||
assert agent_result["locale"] == locale_value, f"Failed for locale: {locale_value}"
|
||||
|
||||
def test_restoration_with_new_messages(self):
|
||||
"""Test that restoration works even when agent adds new messages."""
|
||||
state = State(messages=[], locale="zh-CN", research_topic="research")
|
||||
|
||||
# Agent processes and returns new messages
|
||||
agent_result = {
|
||||
"messages": ["message1", "message2", "message3"],
|
||||
}
|
||||
|
||||
# Restore meta fields
|
||||
agent_result.update(preserve_state_meta_fields(state))
|
||||
|
||||
# Should have both new messages AND preserved meta fields
|
||||
assert len(agent_result["messages"]) == 3
|
||||
assert agent_result["locale"] == "zh-CN"
|
||||
assert agent_result["research_topic"] == "research"
|
||||
|
||||
def test_restoration_idempotent(self):
|
||||
"""Test that restoring meta fields multiple times doesn't cause issues."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
agent_result = {"messages": []}
|
||||
|
||||
# Apply restoration multiple times
|
||||
agent_result.update(preserved)
|
||||
agent_result.update(preserved)
|
||||
agent_result.update(preserved)
|
||||
|
||||
# Should still have correct locale (not corrupted)
|
||||
assert agent_result["locale"] == "en-US"
|
||||
assert len(agent_result) == 9 # messages + 8 meta fields
|
||||
|
||||
|
||||
class TestAgentLocaleRestorationScenarios:
|
||||
"""Real-world scenario tests for agent locale restoration."""
|
||||
|
||||
def test_researcher_agent_preserves_locale(self):
|
||||
"""
|
||||
Simulate researcher agent execution preserving locale.
|
||||
|
||||
Scenario:
|
||||
1. Researcher node receives state with locale="zh-CN"
|
||||
2. Calls agent.ainvoke() which returns only messages
|
||||
3. Restores locale before returning
|
||||
"""
|
||||
# State coming into researcher node
|
||||
state = State(
|
||||
messages=[],
|
||||
locale="zh-CN",
|
||||
research_topic="生产1公斤牛肉需要多少升水?",
|
||||
)
|
||||
|
||||
# Agent executes and returns
|
||||
agent_result = {
|
||||
"messages": ["Researcher analysis of water usage..."],
|
||||
}
|
||||
|
||||
# Apply restoration (what the fix does)
|
||||
agent_result.update(preserve_state_meta_fields(state))
|
||||
|
||||
# Verify for next node
|
||||
assert agent_result["locale"] == "zh-CN" # ✓ Preserved!
|
||||
assert agent_result.get("locale") != "auto" # ✓ Not "auto"
|
||||
|
||||
def test_coder_agent_preserves_locale(self):
|
||||
"""Coder agent should also preserve locale."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
|
||||
agent_result = {"messages": ["Code generation result"]}
|
||||
agent_result.update(preserve_state_meta_fields(state))
|
||||
|
||||
assert agent_result["locale"] == "en-US"
|
||||
|
||||
def test_locale_persists_across_multiple_agents(self):
|
||||
"""Test that locale persists through multiple agent calls."""
|
||||
locales = ["zh-CN", "en-US", "fr-FR"]
|
||||
|
||||
for locale in locales:
|
||||
# Initial state
|
||||
state = State(messages=[], locale=locale)
|
||||
preserved_1 = preserve_state_meta_fields(state)
|
||||
|
||||
# First agent
|
||||
result_1 = {"messages": ["agent1"]}
|
||||
result_1.update(preserved_1)
|
||||
|
||||
# Create state for second agent
|
||||
state_2 = State(messages=result_1["messages"], **preserved_1)
|
||||
preserved_2 = preserve_state_meta_fields(state_2)
|
||||
|
||||
# Second agent
|
||||
result_2 = {"messages": result_1["messages"] + ["agent2"]}
|
||||
result_2.update(preserved_2)
|
||||
|
||||
# Locale should persist
|
||||
assert result_2["locale"] == locale
|
||||
316
tests/unit/graph/test_human_feedback_locale_fix.py
Normal file
316
tests/unit/graph/test_human_feedback_locale_fix.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for the human_feedback_node locale fix.
|
||||
|
||||
Tests that the duplicate locale assignment issue is resolved:
|
||||
- Locale is safely retrieved from new_plan using .get() with fallback
|
||||
- If new_plan['locale'] doesn't exist, it doesn't cause a KeyError
|
||||
- If new_plan['locale'] is None or empty, the preserved state locale is used
|
||||
- If new_plan['locale'] has a valid value, it properly overrides the state locale
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from src.graph.nodes import preserve_state_meta_fields
|
||||
from src.graph.types import State
|
||||
from src.prompts.planner_model import Plan
|
||||
|
||||
|
||||
class TestHumanFeedbackLocaleFixture:
|
||||
"""Test suite for human_feedback_node locale safe handling."""
|
||||
|
||||
def test_preserve_state_meta_fields_no_keyerror(self):
|
||||
"""Test that preserve_state_meta_fields never raises KeyError."""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["locale"] == "zh-CN"
|
||||
assert "locale" in preserved
|
||||
|
||||
def test_new_plan_without_locale_override(self):
|
||||
"""
|
||||
Test scenario: Plan doesn't override locale when not provided in override dict.
|
||||
|
||||
Before fix: Would set locale twice (duplicate assignment)
|
||||
After fix: Uses .get() safely and only overrides if value is truthy
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# Simulate a plan that doesn't want to override the locale
|
||||
# (locale is in the plan for validation, but not in override dict)
|
||||
new_plan_dict = {"title": "Test", "thought": "Test", "steps": [], "locale": "zh-CN", "has_enough_context": False}
|
||||
|
||||
# Get preserved fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Build update dict like the fixed code does
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan_dict),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Simulate a dict that doesn't have locale override (like when plan_dict is empty for override)
|
||||
plan_override = {} # No locale in override dict
|
||||
|
||||
# Only override locale if override dict provides a valid value
|
||||
if plan_override.get("locale"):
|
||||
update_dict["locale"] = plan_override["locale"]
|
||||
|
||||
# The preserved locale should be used when override doesn't provide one
|
||||
assert update_dict["locale"] == "zh-CN"
|
||||
|
||||
def test_new_plan_with_none_locale(self):
|
||||
"""
|
||||
Test scenario: new_plan has locale=None.
|
||||
|
||||
Before fix: Would try to set locale to None (but Plan requires it)
|
||||
After fix: Uses preserved state locale since new_plan.get("locale") is falsy
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# new_plan with None locale (won't validate, but test the logic)
|
||||
new_plan_attempt = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
# Get preserved fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Build update dict like the fixed code does
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan_attempt),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Simulate checking for None locale (if it somehow got set)
|
||||
new_plan_with_none = {"locale": None}
|
||||
# Only override if new_plan provides a VALID value
|
||||
if new_plan_with_none.get("locale"):
|
||||
update_dict["locale"] = new_plan_with_none["locale"]
|
||||
|
||||
# Should use preserved locale (zh-CN), not None
|
||||
assert update_dict["locale"] == "zh-CN"
|
||||
assert update_dict["locale"] is not None
|
||||
|
||||
def test_new_plan_with_empty_string_locale(self):
|
||||
"""
|
||||
Test scenario: new_plan has locale="" (empty string).
|
||||
|
||||
Before fix: Would try to set locale to "" (but Plan requires valid value)
|
||||
After fix: Uses preserved state locale since empty string is falsy
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# new_plan with valid locale
|
||||
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
# Get preserved fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Build update dict like the fixed code does
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Simulate checking for empty string locale
|
||||
new_plan_empty = {"locale": ""}
|
||||
# Only override if new_plan provides a VALID (truthy) value
|
||||
if new_plan_empty.get("locale"):
|
||||
update_dict["locale"] = new_plan_empty["locale"]
|
||||
|
||||
# Should use preserved locale (zh-CN), not empty string
|
||||
assert update_dict["locale"] == "zh-CN"
|
||||
assert update_dict["locale"] != ""
|
||||
|
||||
def test_new_plan_with_valid_locale_overrides(self):
|
||||
"""
|
||||
Test scenario: new_plan has valid locale="en-US".
|
||||
|
||||
Before fix: Would override with new_plan locale ✓ (worked)
|
||||
After fix: Should still properly override with valid locale
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# new_plan has a different valid locale
|
||||
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
# Get preserved fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Build update dict like the fixed code does
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Override if new_plan provides a VALID value
|
||||
if new_plan.get("locale"):
|
||||
update_dict["locale"] = new_plan["locale"]
|
||||
|
||||
# Should override with new_plan locale
|
||||
assert update_dict["locale"] == "en-US"
|
||||
assert update_dict["locale"] != "zh-CN"
|
||||
|
||||
def test_locale_field_not_duplicated(self):
|
||||
"""
|
||||
Test that locale field is not duplicated in the update dict.
|
||||
|
||||
Before fix: locale was set twice in the same dict
|
||||
After fix: locale is only set once
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Count how many times 'locale' is set
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved, # Sets locale once
|
||||
}
|
||||
|
||||
# Override locale only if new_plan provides valid value
|
||||
if new_plan.get("locale"):
|
||||
update_dict["locale"] = new_plan["locale"]
|
||||
|
||||
# Verify locale is in dict exactly once
|
||||
locale_count = sum(1 for k in update_dict.keys() if k == "locale")
|
||||
assert locale_count == 1
|
||||
assert update_dict["locale"] == "en-US" # Should be overridden
|
||||
|
||||
def test_all_meta_fields_preserved(self):
|
||||
"""
|
||||
Test that all 8 meta fields are preserved along with locale fix.
|
||||
|
||||
Ensures the fix doesn't break other meta field preservation.
|
||||
"""
|
||||
state = State(
|
||||
messages=[],
|
||||
locale="zh-CN",
|
||||
research_topic="Research",
|
||||
clarified_research_topic="Clarified",
|
||||
clarification_history=["Q1"],
|
||||
enable_clarification=True,
|
||||
max_clarification_rounds=5,
|
||||
clarification_rounds=1,
|
||||
resources=["resource1"],
|
||||
)
|
||||
|
||||
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# All 8 meta fields should be in preserved
|
||||
meta_fields = [
|
||||
"locale",
|
||||
"research_topic",
|
||||
"clarified_research_topic",
|
||||
"clarification_history",
|
||||
"enable_clarification",
|
||||
"max_clarification_rounds",
|
||||
"clarification_rounds",
|
||||
"resources",
|
||||
]
|
||||
|
||||
for field in meta_fields:
|
||||
assert field in preserved
|
||||
|
||||
# Build update dict
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Override locale if new_plan provides valid value
|
||||
if new_plan.get("locale"):
|
||||
update_dict["locale"] = new_plan["locale"]
|
||||
|
||||
# All meta fields should be in update_dict
|
||||
for field in meta_fields:
|
||||
assert field in update_dict
|
||||
|
||||
|
||||
class TestHumanFeedbackLocaleScenarios:
|
||||
"""Real-world scenarios for human_feedback_node locale handling."""
|
||||
|
||||
def test_scenario_chinese_locale_preserved_when_plan_has_no_locale(self):
|
||||
"""
|
||||
Scenario: User selected Chinese, plan preserves it.
|
||||
|
||||
Expected: Preserved Chinese locale should be used
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# Plan from planner with required fields
|
||||
new_plan_json = {
|
||||
"title": "Research Plan",
|
||||
"thought": "...",
|
||||
"steps": [
|
||||
{
|
||||
"title": "Step 1",
|
||||
"description": "...",
|
||||
"need_search": True,
|
||||
"step_type": "research",
|
||||
}
|
||||
],
|
||||
"locale": "zh-CN",
|
||||
"has_enough_context": False,
|
||||
}
|
||||
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan_json),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
if new_plan_json.get("locale"):
|
||||
update_dict["locale"] = new_plan_json["locale"]
|
||||
|
||||
# Chinese locale should be preserved
|
||||
assert update_dict["locale"] == "zh-CN"
|
||||
|
||||
def test_scenario_en_us_restored_even_if_plan_minimal(self):
|
||||
"""
|
||||
Scenario: Minimal plan with en-US locale.
|
||||
|
||||
Expected: Preserved en-US locale should survive
|
||||
"""
|
||||
state = State(messages=[], locale="en-US")
|
||||
|
||||
# Minimal plan with required fields
|
||||
new_plan_json = {"title": "Quick Plan", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan_json),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
if new_plan_json.get("locale"):
|
||||
update_dict["locale"] = new_plan_json["locale"]
|
||||
|
||||
# en-US should survive
|
||||
assert update_dict["locale"] == "en-US"
|
||||
|
||||
def test_scenario_multiple_locale_updates_safe(self):
|
||||
"""
|
||||
Scenario: Multiple plan iterations with locale preservation.
|
||||
|
||||
Expected: Each iteration safely handles locale
|
||||
"""
|
||||
locales = ["zh-CN", "en-US", "fr-FR"]
|
||||
|
||||
for locale in locales:
|
||||
state = State(messages=[], locale=locale)
|
||||
new_plan = {"title": "Plan", "steps": [], "locale": locale, "has_enough_context": False}
|
||||
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
if new_plan.get("locale"):
|
||||
update_dict["locale"] = new_plan["locale"]
|
||||
|
||||
# Each iteration should preserve its locale
|
||||
assert update_dict["locale"] == locale
|
||||
355
tests/unit/graph/test_state_preservation.py
Normal file
355
tests/unit/graph/test_state_preservation.py
Normal file
@@ -0,0 +1,355 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for state preservation functionality in graph nodes.
|
||||
|
||||
Tests the preserve_state_meta_fields() function and verifies that
|
||||
critical state fields (especially locale) are properly preserved
|
||||
across node state transitions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from langgraph.types import Command
|
||||
|
||||
from src.graph.nodes import preserve_state_meta_fields
|
||||
from src.graph.types import State
|
||||
|
||||
|
||||
class TestPreserveStateMetaFields:
|
||||
"""Test suite for preserve_state_meta_fields() function."""
|
||||
|
||||
def test_preserve_all_fields_with_defaults(self):
|
||||
"""Test that all fields are preserved with default values when state is empty."""
|
||||
# Create a minimal state with only messages
|
||||
state = State(messages=[])
|
||||
|
||||
# Extract meta fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Verify all expected fields are present
|
||||
assert "locale" in preserved
|
||||
assert "research_topic" in preserved
|
||||
assert "clarified_research_topic" in preserved
|
||||
assert "clarification_history" in preserved
|
||||
assert "enable_clarification" in preserved
|
||||
assert "max_clarification_rounds" in preserved
|
||||
assert "clarification_rounds" in preserved
|
||||
assert "resources" in preserved
|
||||
|
||||
# Verify default values
|
||||
assert preserved["locale"] == "en-US"
|
||||
assert preserved["research_topic"] == ""
|
||||
assert preserved["clarified_research_topic"] == ""
|
||||
assert preserved["clarification_history"] == []
|
||||
assert preserved["enable_clarification"] is False
|
||||
assert preserved["max_clarification_rounds"] == 3
|
||||
assert preserved["clarification_rounds"] == 0
|
||||
assert preserved["resources"] == []
|
||||
|
||||
def test_preserve_locale_from_state(self):
|
||||
"""Test that locale is correctly preserved when set in state."""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["locale"] == "zh-CN"
|
||||
|
||||
def test_preserve_locale_english(self):
|
||||
"""Test that English locale is preserved."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["locale"] == "en-US"
|
||||
|
||||
def test_preserve_locale_with_custom_value(self):
|
||||
"""Test that custom locale values are preserved."""
|
||||
state = State(messages=[], locale="fr-FR")
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["locale"] == "fr-FR"
|
||||
|
||||
def test_preserve_research_topic(self):
|
||||
"""Test that research_topic is correctly preserved."""
|
||||
test_topic = "How to build sustainable cities"
|
||||
state = State(messages=[], research_topic=test_topic)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["research_topic"] == test_topic
|
||||
|
||||
def test_preserve_clarified_research_topic(self):
|
||||
"""Test that clarified_research_topic is correctly preserved."""
|
||||
test_topic = "Sustainable urban development with focus on green spaces"
|
||||
state = State(messages=[], clarified_research_topic=test_topic)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["clarified_research_topic"] == test_topic
|
||||
|
||||
def test_preserve_clarification_history(self):
|
||||
"""Test that clarification_history is correctly preserved."""
|
||||
history = ["Q: What aspects?", "A: Architecture and planning"]
|
||||
state = State(messages=[], clarification_history=history)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["clarification_history"] == history
|
||||
|
||||
def test_preserve_clarification_flags(self):
|
||||
"""Test that clarification flags are correctly preserved."""
|
||||
state = State(
|
||||
messages=[],
|
||||
enable_clarification=True,
|
||||
max_clarification_rounds=5,
|
||||
clarification_rounds=2,
|
||||
)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["enable_clarification"] is True
|
||||
assert preserved["max_clarification_rounds"] == 5
|
||||
assert preserved["clarification_rounds"] == 2
|
||||
|
||||
def test_preserve_resources(self):
|
||||
"""Test that resources list is correctly preserved."""
|
||||
resources = [{"id": "1", "name": "Resource 1"}]
|
||||
state = State(messages=[], resources=resources)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["resources"] == resources
|
||||
|
||||
def test_preserve_all_fields_together(self):
|
||||
"""Test that all meta fields are preserved together correctly."""
|
||||
state = State(
|
||||
messages=[],
|
||||
locale="zh-CN",
|
||||
research_topic="原始查询",
|
||||
clarified_research_topic="澄清后的查询",
|
||||
clarification_history=["Q1", "A1", "Q2", "A2"],
|
||||
enable_clarification=True,
|
||||
max_clarification_rounds=5,
|
||||
clarification_rounds=2,
|
||||
resources=["resource1"],
|
||||
)
|
||||
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["locale"] == "zh-CN"
|
||||
assert preserved["research_topic"] == "原始查询"
|
||||
assert preserved["clarified_research_topic"] == "澄清后的查询"
|
||||
assert preserved["clarification_history"] == ["Q1", "A1", "Q2", "A2"]
|
||||
assert preserved["enable_clarification"] is True
|
||||
assert preserved["max_clarification_rounds"] == 5
|
||||
assert preserved["clarification_rounds"] == 2
|
||||
assert preserved["resources"] == ["resource1"]
|
||||
|
||||
def test_preserve_returns_dict_not_state_object(self):
|
||||
"""Test that preserve_state_meta_fields returns a dict."""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert isinstance(preserved, dict)
|
||||
# Verify it's a plain dict with expected keys
|
||||
assert "locale" in preserved
|
||||
assert "research_topic" in preserved
|
||||
|
||||
def test_preserve_does_not_mutate_original_state(self):
|
||||
"""Test that calling preserve_state_meta_fields does not mutate the original state."""
|
||||
original_locale = "zh-CN"
|
||||
state = State(messages=[], locale=original_locale)
|
||||
original_state_copy = dict(state)
|
||||
|
||||
preserve_state_meta_fields(state)
|
||||
|
||||
# Verify state hasn't changed
|
||||
assert state["locale"] == original_locale
|
||||
assert dict(state) == original_state_copy
|
||||
|
||||
def test_preserve_with_none_values(self):
|
||||
"""Test that preserve handles None values gracefully."""
|
||||
state = State(messages=[], locale=None)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Should use default when value is None
|
||||
assert preserved["locale"] is None or preserved["locale"] == "en-US"
|
||||
|
||||
def test_preserve_empty_lists_preserved(self):
|
||||
"""Test that empty lists are preserved correctly."""
|
||||
state = State(
|
||||
messages=[], clarification_history=[], resources=[]
|
||||
)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["clarification_history"] == []
|
||||
assert preserved["resources"] == []
|
||||
|
||||
def test_preserve_count_of_fields(self):
|
||||
"""Test that exactly 8 fields are preserved."""
|
||||
state = State(messages=[])
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Should have exactly 8 meta fields
|
||||
assert len(preserved) == 8
|
||||
|
||||
def test_preserve_field_names(self):
|
||||
"""Test that all expected field names are present."""
|
||||
state = State(messages=[])
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
expected_fields = {
|
||||
"locale",
|
||||
"research_topic",
|
||||
"clarified_research_topic",
|
||||
"clarification_history",
|
||||
"enable_clarification",
|
||||
"max_clarification_rounds",
|
||||
"clarification_rounds",
|
||||
"resources",
|
||||
}
|
||||
|
||||
assert set(preserved.keys()) == expected_fields
|
||||
|
||||
|
||||
class TestStatePreservationInCommand:
|
||||
"""Test suite for using preserved state fields in Command objects."""
|
||||
|
||||
def test_command_update_with_preserved_fields(self):
|
||||
"""Test that preserved fields can be unpacked into Command.update."""
|
||||
state = State(messages=[], locale="zh-CN", research_topic="测试")
|
||||
|
||||
# This should not raise any errors
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
command_update = {
|
||||
"messages": [],
|
||||
**preserved,
|
||||
}
|
||||
|
||||
assert "locale" in command_update
|
||||
assert "research_topic" in command_update
|
||||
assert command_update["locale"] == "zh-CN"
|
||||
|
||||
def test_command_unpacking_syntax(self):
|
||||
"""Test that the unpacking syntax works correctly with preserved fields."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Simulate how it's used in actual nodes
|
||||
update_dict = {
|
||||
"messages": [],
|
||||
"current_plan": None,
|
||||
**preserved,
|
||||
"locale": "zh-CN",
|
||||
}
|
||||
|
||||
assert len(update_dict) >= 10 # 2 explicit + 8 preserved
|
||||
assert update_dict["locale"] == "zh-CN" # overridden value
|
||||
|
||||
|
||||
class TestLocalePreservationSpecific:
|
||||
"""Specific test cases for locale preservation (the main issue being fixed)."""
|
||||
|
||||
def test_locale_not_lost_in_transition(self):
|
||||
"""Test that locale is not lost when transitioning between nodes."""
|
||||
# Initial state from frontend with Chinese locale
|
||||
initial_state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# Extract for first node transition
|
||||
preserved_1 = preserve_state_meta_fields(initial_state)
|
||||
|
||||
# Simulate state update from first node
|
||||
updated_state_1 = State(
|
||||
messages=[], **preserved_1
|
||||
)
|
||||
|
||||
# Extract for second node transition
|
||||
preserved_2 = preserve_state_meta_fields(updated_state_1)
|
||||
|
||||
# Locale should still be zh-CN after two transitions
|
||||
assert preserved_2["locale"] == "zh-CN"
|
||||
|
||||
def test_locale_chain_through_multiple_nodes(self):
|
||||
"""Test that locale survives through multiple node transitions."""
|
||||
initial_locale = "zh-CN"
|
||||
state = State(messages=[], locale=initial_locale)
|
||||
|
||||
# Simulate 5 node transitions
|
||||
for _ in range(5):
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
assert preserved["locale"] == initial_locale
|
||||
|
||||
# Create new state for next "node"
|
||||
state = State(messages=[], **preserved)
|
||||
|
||||
# After 5 transitions, locale should still be preserved
|
||||
assert state.get("locale") == initial_locale
|
||||
|
||||
def test_locale_with_other_fields_preserved_together(self):
|
||||
"""Test that locale is preserved correctly even when other fields change."""
|
||||
initial_state = State(
|
||||
messages=[],
|
||||
locale="zh-CN",
|
||||
research_topic="Original",
|
||||
clarification_rounds=0,
|
||||
)
|
||||
|
||||
preserved = preserve_state_meta_fields(initial_state)
|
||||
|
||||
# Verify locale is in preserved dict
|
||||
assert preserved["locale"] == "zh-CN"
|
||||
assert preserved["research_topic"] == "Original"
|
||||
assert preserved["clarification_rounds"] == 0
|
||||
|
||||
# Create new state with preserved fields
|
||||
modified_state = State(
|
||||
messages=[],
|
||||
**preserved,
|
||||
)
|
||||
|
||||
# Locale should be preserved
|
||||
assert modified_state.get("locale") == "zh-CN"
|
||||
# Research topic should be preserved from original
|
||||
assert modified_state.get("research_topic") == "Original"
|
||||
assert modified_state.get("clarification_rounds") == 0
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and boundary conditions."""
|
||||
|
||||
def test_very_long_research_topic(self):
|
||||
"""Test preservation with very long research_topic."""
|
||||
long_topic = "a" * 10000
|
||||
state = State(messages=[], research_topic=long_topic)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["research_topic"] == long_topic
|
||||
|
||||
def test_unicode_characters_in_topic(self):
|
||||
"""Test preservation with unicode characters."""
|
||||
unicode_topic = "中文测试 🌍 テスト 🧪"
|
||||
state = State(messages=[], research_topic=unicode_topic)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert preserved["research_topic"] == unicode_topic
|
||||
|
||||
def test_special_characters_in_locale(self):
|
||||
"""Test preservation with special locale formats."""
|
||||
special_locales = ["zh-CN", "en-US", "pt-BR", "es-ES", "ja-JP"]
|
||||
|
||||
for locale in special_locales:
|
||||
state = State(messages=[], locale=locale)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
assert preserved["locale"] == locale
|
||||
|
||||
def test_large_clarification_history(self):
|
||||
"""Test preservation with large clarification_history."""
|
||||
large_history = [f"Q{i}: Question {i}" for i in range(100)]
|
||||
state = State(messages=[], clarification_history=large_history)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
assert len(preserved["clarification_history"]) == 100
|
||||
assert preserved["clarification_history"] == large_history
|
||||
|
||||
def test_max_clarification_rounds_boundary(self):
|
||||
"""Test preservation with boundary values for max_clarification_rounds."""
|
||||
test_cases = [0, 1, 3, 10, 100, 999]
|
||||
|
||||
for value in test_cases:
|
||||
state = State(messages=[], max_clarification_rounds=value)
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
assert preserved["max_clarification_rounds"] == value
|
||||
Reference in New Issue
Block a user