fix: preserve the locale setting between frontend and backend (#670)

* fix: preserve the locale setting between frontend and backend

* Added unit test for the state preservation

* fix: passing the locale to the agent call

* fix: apply the fix after code review
This commit is contained in:
Willem Jiang
2025-10-28 21:45:29 +08:00
committed by GitHub
parent eb4c3b8ef6
commit 0415f622da
6 changed files with 994 additions and 21 deletions

View File

@@ -84,6 +84,31 @@ def needs_clarification(state: dict) -> bool:
)
def preserve_state_meta_fields(state: State) -> dict:
    """
    Extract meta/config fields that should be preserved across state transitions.

    These fields are meant to be included explicitly in Command.update dicts so
    that they do not revert to defaults between nodes.

    A field that is absent from ``state`` *or* explicitly set to ``None`` is
    replaced by its default. Falsy but meaningful values (``0``, ``""``,
    ``False``, ``[]``) are kept as-is.

    Args:
        state: Current state object (any mapping exposing ``.get``).

    Returns:
        Dict with exactly the eight meta fields listed below.
    """
    # Built on every call so the list defaults are never shared between callers.
    defaults = {
        "locale": "en-US",
        "research_topic": "",
        "clarified_research_topic": "",
        "clarification_history": [],
        "enable_clarification": False,
        "max_clarification_rounds": 3,
        "clarification_rounds": 0,
        "resources": [],
    }

    preserved = {}
    for field_name, default in defaults.items():
        value = state.get(field_name)
        # `is None` (not truthiness) so that 0 / "" / False / [] survive.
        preserved[field_name] = default if value is None else value
    return preserved
def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
"""
Validate and fix a plan to ensure it meets requirements.
@@ -217,7 +242,7 @@ def planner_node(
state: State, config: RunnableConfig
) -> Command[Literal["human_feedback", "reporter"]]:
"""Planner node that generate the full plan."""
logger.info("Planner generating full plan")
logger.info("Planner generating full plan with locale: %s", state.get("locale", "en-US"))
configurable = Configuration.from_runnable_config(config)
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
@@ -266,7 +291,10 @@ def planner_node(
# if the plan iterations is greater than the max plan iterations, return the reporter node
if plan_iterations >= configurable.max_plan_iterations:
return Command(goto="reporter")
return Command(
update=preserve_state_meta_fields(state),
goto="reporter"
)
full_response = ""
if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking:
@@ -284,9 +312,15 @@ def planner_node(
except json.JSONDecodeError:
logger.warning("Planner response is not a valid JSON")
if plan_iterations > 0:
return Command(goto="reporter")
return Command(
update=preserve_state_meta_fields(state),
goto="reporter"
)
else:
return Command(goto="__end__")
return Command(
update=preserve_state_meta_fields(state),
goto="__end__"
)
# Validate and fix plan to ensure web search requirements are met
if isinstance(curr_plan, dict):
@@ -299,6 +333,7 @@ def planner_node(
update={
"messages": [AIMessage(content=full_response, name="planner")],
"current_plan": new_plan,
**preserve_state_meta_fields(state),
},
goto="reporter",
)
@@ -306,6 +341,7 @@ def planner_node(
update={
"messages": [AIMessage(content=full_response, name="planner")],
"current_plan": full_response,
**preserve_state_meta_fields(state),
},
goto="human_feedback",
)
@@ -323,7 +359,10 @@ def human_feedback_node(
# Handle None or empty feedback
if not feedback:
logger.warning(f"Received empty or None feedback: {feedback}. Returning to planner for new plan.")
return Command(goto="planner")
return Command(
update=preserve_state_meta_fields(state),
goto="planner"
)
# Normalize feedback string
feedback_normalized = str(feedback).strip().upper()
@@ -336,6 +375,7 @@ def human_feedback_node(
"messages": [
HumanMessage(content=feedback, name="feedback"),
],
**preserve_state_meta_fields(state),
},
goto="planner",
)
@@ -343,7 +383,10 @@ def human_feedback_node(
logger.info("Plan is accepted by user.")
else:
logger.warning(f"Unsupported feedback format: {feedback}. Please use '[ACCEPTED]' to accept or '[EDIT_PLAN]' to edit.")
return Command(goto="planner")
return Command(
update=preserve_state_meta_fields(state),
goto="planner"
)
# if the plan is accepted, run the following node
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
@@ -360,16 +403,29 @@ def human_feedback_node(
except json.JSONDecodeError:
logger.warning("Planner response is not a valid JSON")
if plan_iterations > 1: # the plan_iterations is increased before this check
return Command(goto="reporter")
return Command(
update=preserve_state_meta_fields(state),
goto="reporter"
)
else:
return Command(goto="__end__")
return Command(
update=preserve_state_meta_fields(state),
goto="__end__"
)
# Build update dict with safe locale handling
update_dict = {
"current_plan": Plan.model_validate(new_plan),
"plan_iterations": plan_iterations,
**preserve_state_meta_fields(state),
}
# Only override locale if new_plan provides a valid value, otherwise use preserved locale
if new_plan.get("locale"):
update_dict["locale"] = new_plan["locale"]
return Command(
update={
"current_plan": Plan.model_validate(new_plan),
"plan_iterations": plan_iterations,
"locale": new_plan["locale"],
},
update=update_dict,
goto=goto,
)
@@ -408,6 +464,7 @@ def coordinator_node(
goto = "__end__"
locale = state.get("locale", "en-US")
logger.info(f"Coordinator locale: {locale}")
research_topic = state.get("research_topic", "")
# Process tool calls for legacy mode
@@ -421,9 +478,8 @@ def coordinator_node(
logger.info("Handing off to planner")
goto = "planner"
# Extract locale and research_topic if provided
if tool_args.get("locale") and tool_args.get("research_topic"):
locale = tool_args.get("locale")
# Extract research_topic if provided
if tool_args.get("research_topic"):
research_topic = tool_args.get("research_topic")
break
@@ -587,8 +643,6 @@ def coordinator_node(
logger.info("Handing off to planner")
goto = "planner"
# Extract locale if provided
locale = tool_args.get("locale", locale)
if not enable_clarification and tool_args.get("research_topic"):
research_topic = tool_args["research_topic"]
@@ -725,7 +779,10 @@ async def _execute_agent_step(
if not current_step:
logger.warning(f"[_execute_agent_step] No unexecuted step found in {len(current_plan.steps)} total steps")
return Command(goto="research_team")
return Command(
update=preserve_state_meta_fields(state),
goto="research_team"
)
logger.info(f"[_execute_agent_step] Executing step: {current_step.title}, agent: {agent_name}")
logger.debug(f"[_execute_agent_step] Completed steps so far: {len(completed_steps)}")
@@ -834,6 +891,7 @@ async def _execute_agent_step(
)
],
"observations": observations + [detailed_error],
**preserve_state_meta_fields(state),
},
goto="research_team",
)
@@ -859,6 +917,7 @@ async def _execute_agent_step(
)
],
"observations": observations + [response_content],
**preserve_state_meta_fields(state),
},
goto="research_team",
)

View File

@@ -106,6 +106,8 @@ async def chat_stream(request: ChatRequest):
# Check if MCP server configuration is enabled
mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)
logger.debug(f"get the request locale : {request.locale}")
# Validate MCP settings if provided
if request.mcp_settings and not mcp_enabled:
raise HTTPException(

View File

@@ -675,7 +675,7 @@ def test_coordinator_node_with_tool_calls_locale_override(
tool_calls = [
{
"name": "handoff_to_planner",
"args": {"locale": "zh-CN", "research_topic": "test topic"},
"args": {"locale": "auto", "research_topic": "test topic"},
}
]
with (
@@ -689,7 +689,7 @@ def test_coordinator_node_with_tool_calls_locale_override(
result = coordinator_node(mock_state_coordinator, MagicMock())
assert result.goto == "planner"
assert result.update["locale"] == "zh-CN"
assert result.update["locale"] == "en-US"
assert result.update["research_topic"] == "test topic"
assert result.update["resources"] == ["resource1", "resource2"]
assert result.update["resources"] == ["resource1", "resource2"]

View File

@@ -0,0 +1,241 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for agent locale restoration after create_react_agent execution.
Tests that meta fields (especially locale) are properly restored after
agent.ainvoke() returns, since create_react_agent creates a MessagesState
subgraph that filters out custom fields.
"""
import pytest
from src.graph.nodes import preserve_state_meta_fields
from src.graph.types import State
class TestAgentLocaleRestoration:
    """Checks that meta fields can be merged back into a messages-only agent result."""

    def test_locale_lost_in_agent_subgraph(self):
        """Show the failure mode: a messages-only result carries no locale."""
        # The incoming state has a locale, but nothing below reads it.
        initial_state = State(messages=[], locale="zh-CN")

        # Stand-in for a result that contains messages and nothing else.
        agent_result = dict(messages=["agent response"])

        assert agent_result.get("locale") is None
        assert "locale" not in agent_result

    def test_locale_restoration_after_agent(self):
        """Merging the preserved fields puts locale and topic back."""
        source_state = State(messages=[], locale="zh-CN", research_topic="test")

        merged = {
            "messages": ["agent response"],
            **preserve_state_meta_fields(source_state),
        }

        assert "messages" in merged
        assert merged["research_topic"] == "test"
        assert merged["locale"] == "zh-CN"

    def test_all_meta_fields_restored(self):
        """Every meta field, not only locale, is present after the merge."""
        expected = {
            "locale": "en-US",
            "research_topic": "Original Topic",
            "clarified_research_topic": "Clarified Topic",
            "clarification_history": ["Q1", "A1"],
            "enable_clarification": True,
            "max_clarification_rounds": 5,
            "clarification_rounds": 2,
            "resources": ["resource1"],
        }
        source_state = State(messages=[], **expected)

        merged = {"messages": ["response"]}
        merged.update(preserve_state_meta_fields(source_state))

        for field_name, wanted in expected.items():
            assert merged[field_name] == wanted
        # The flag must be the real boolean, not merely something equal to it.
        assert merged["enable_clarification"] is True

    def test_locale_preservation_through_agent_cycle(self):
        """Walk the full cycle: extract, lose, merge back, rebuild the state."""
        source_state = State(messages=[], locale="zh-CN")

        # Extract before the simulated agent call.
        kept = preserve_state_meta_fields(source_state)
        assert kept["locale"] == "zh-CN"

        # Simulated agent output: messages only.
        outcome = {"messages": ["agent output"]}
        assert "locale" not in outcome

        # Merge the meta fields back and check the locale.
        outcome.update(kept)
        assert outcome["locale"] == "zh-CN"

        # A state rebuilt from the merged pieces still reports the locale.
        rebuilt = State(messages=outcome["messages"], **kept)
        assert rebuilt.get("locale") == "zh-CN"

    def test_locale_not_auto_after_restoration(self):
        """After the merge the locale is the preserved value, never "auto"."""
        source_state = State(messages=[], locale="zh-CN")

        outcome = {"messages": [], **preserve_state_meta_fields(source_state)}
        restored_locale = outcome.get("locale")

        assert restored_locale is not None
        assert restored_locale != "auto"
        assert restored_locale == "zh-CN"

    def test_chinese_locale_preserved(self):
        """Several Chinese locale tags each survive the merge unchanged."""
        for locale_value in ["zh-CN", "zh", "zh-Hans", "zh-Hant"]:
            source_state = State(messages=[], locale=locale_value)
            outcome = {"messages": []}
            outcome.update(preserve_state_meta_fields(source_state))
            assert outcome["locale"] == locale_value, f"Failed for locale: {locale_value}"

    def test_restoration_with_new_messages(self):
        """New messages in the result coexist with the merged meta fields."""
        source_state = State(messages=[], locale="zh-CN", research_topic="research")

        outcome = {"messages": ["message1", "message2", "message3"]}
        outcome.update(preserve_state_meta_fields(source_state))

        assert outcome["research_topic"] == "research"
        assert outcome["locale"] == "zh-CN"
        assert len(outcome["messages"]) == 3

    def test_restoration_idempotent(self):
        """Merging the same preserved dict repeatedly changes nothing further."""
        kept = preserve_state_meta_fields(State(messages=[], locale="en-US"))

        outcome = {"messages": []}
        for _ in range(3):
            outcome.update(kept)

        assert outcome["locale"] == "en-US"
        # "messages" plus the eight meta fields.
        assert len(outcome) == 9
class TestAgentLocaleRestorationScenarios:
    """Scenario-style tests modelling agent calls followed by a meta-field merge."""

    def test_researcher_agent_preserves_locale(self):
        """A simulated researcher result gets its zh-CN locale merged back in.

        The incoming state carries locale="zh-CN"; the simulated agent output
        has messages only; merging the preserved fields restores the locale.
        """
        incoming = State(
            messages=[],
            locale="zh-CN",
            research_topic="生产1公斤牛肉需要多少升水",
        )

        outcome = {
            "messages": ["Researcher analysis of water usage..."],
            **preserve_state_meta_fields(incoming),
        }

        assert outcome.get("locale") != "auto"
        assert outcome["locale"] == "zh-CN"

    def test_coder_agent_preserves_locale(self):
        """A simulated coder result keeps the en-US locale after the merge."""
        incoming = State(messages=[], locale="en-US")

        outcome = {"messages": ["Code generation result"]}
        outcome.update(preserve_state_meta_fields(incoming))

        assert outcome["locale"] == "en-US"

    def test_locale_persists_across_multiple_agents(self):
        """The locale survives two chained simulated agent calls."""
        for locale in ["zh-CN", "en-US", "fr-FR"]:
            # First hop: extract from the initial state and merge into result one.
            first_kept = preserve_state_meta_fields(State(messages=[], locale=locale))
            first_result = {"messages": ["agent1"]}
            first_result.update(first_kept)

            # Second hop: rebuild a state from hop one, extract and merge again.
            second_state = State(messages=first_result["messages"], **first_kept)
            second_kept = preserve_state_meta_fields(second_state)
            second_result = {"messages": first_result["messages"] + ["agent2"]}
            second_result.update(second_kept)

            assert second_result["locale"] == locale

View File

@@ -0,0 +1,316 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for the human_feedback_node locale fix.
Tests that the duplicate locale assignment issue is resolved:
- Locale is safely retrieved from new_plan using .get() with fallback
- If new_plan['locale'] doesn't exist, it doesn't cause a KeyError
- If new_plan['locale'] is None or empty, the preserved state locale is used
- If new_plan['locale'] has a valid value, it properly overrides the state locale
"""
import pytest
from src.graph.nodes import preserve_state_meta_fields
from src.graph.types import State
from src.prompts.planner_model import Plan
class TestHumanFeedbackLocaleFixture:
    """Tests modelling the locale-override logic used when building an update dict."""

    @staticmethod
    def _plan_dict(locale):
        """Return a minimal plan payload carrying the given locale."""
        return {
            "title": "Test",
            "thought": "Test",
            "steps": [],
            "locale": locale,
            "has_enough_context": False,
        }

    @staticmethod
    def _build_update(preserved, plan_dict, locale_source):
        """Build the update dict and apply the conditional locale override.

        The plan is validated into a ``Plan``, the preserved meta fields are
        merged in, and "locale" is overridden only when ``locale_source``
        holds a truthy "locale" value.
        """
        update = {
            "current_plan": Plan.model_validate(plan_dict),
            **preserved,
        }
        if locale_source.get("locale"):
            update["locale"] = locale_source["locale"]
        return update

    def test_preserve_state_meta_fields_no_keyerror(self):
        """The preserved dict exposes "locale" with the state's value."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

        assert "locale" in kept
        assert kept["locale"] == "zh-CN"

    def test_new_plan_without_locale_override(self):
        """An override source without a "locale" key leaves the preserved locale."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

        # Empty mapping: there is no locale to override with.
        update = self._build_update(kept, self._plan_dict("zh-CN"), {})

        assert update["locale"] == "zh-CN"

    def test_new_plan_with_none_locale(self):
        """A None locale in the override source is ignored."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

        # The validated plan has a valid locale; the override source has None.
        update = self._build_update(kept, self._plan_dict("en-US"), {"locale": None})

        assert update["locale"] is not None
        assert update["locale"] == "zh-CN"

    def test_new_plan_with_empty_string_locale(self):
        """An empty-string locale in the override source is ignored."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

        update = self._build_update(kept, self._plan_dict("en-US"), {"locale": ""})

        assert update["locale"] != ""
        assert update["locale"] == "zh-CN"

    def test_new_plan_with_valid_locale_overrides(self):
        """A truthy locale in the plan replaces the preserved one."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))
        plan = self._plan_dict("en-US")

        update = self._build_update(kept, plan, plan)

        assert update["locale"] != "zh-CN"
        assert update["locale"] == "en-US"

    def test_locale_field_not_duplicated(self):
        """The update dict holds exactly one "locale" key, with the overriding value."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))
        plan = self._plan_dict("en-US")

        update = self._build_update(kept, plan, plan)

        assert list(update).count("locale") == 1
        assert update["locale"] == "en-US"

    def test_all_meta_fields_preserved(self):
        """All eight meta fields appear in the preserved dict and in the update."""
        source_state = State(
            messages=[],
            locale="zh-CN",
            research_topic="Research",
            clarified_research_topic="Clarified",
            clarification_history=["Q1"],
            enable_clarification=True,
            max_clarification_rounds=5,
            clarification_rounds=1,
            resources=["resource1"],
        )
        plan = self._plan_dict("en-US")
        kept = preserve_state_meta_fields(source_state)

        meta_fields = (
            "locale",
            "research_topic",
            "clarified_research_topic",
            "clarification_history",
            "enable_clarification",
            "max_clarification_rounds",
            "clarification_rounds",
            "resources",
        )
        for name in meta_fields:
            assert name in kept

        update = self._build_update(kept, plan, plan)

        for name in meta_fields:
            assert name in update
class TestHumanFeedbackLocaleScenarios:
    """Scenario-style tests for the locale handling modelled on human_feedback_node."""

    def test_scenario_chinese_locale_preserved_when_plan_has_no_locale(self):
        """State and plan both carry zh-CN; the update ends up with zh-CN."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

        # Plan payload with one research step.
        plan_payload = {
            "title": "Research Plan",
            "thought": "...",
            "steps": [
                {
                    "title": "Step 1",
                    "description": "...",
                    "need_search": True,
                    "step_type": "research",
                }
            ],
            "locale": "zh-CN",
            "has_enough_context": False,
        }

        update = {"current_plan": Plan.model_validate(plan_payload)}
        update.update(kept)
        plan_locale = plan_payload.get("locale")
        if plan_locale:
            update["locale"] = plan_locale

        assert update["locale"] == "zh-CN"

    def test_scenario_en_us_restored_even_if_plan_minimal(self):
        """A plan without a "thought" field still leaves en-US in the update."""
        kept = preserve_state_meta_fields(State(messages=[], locale="en-US"))

        plan_payload = {
            "title": "Quick Plan",
            "steps": [],
            "locale": "en-US",
            "has_enough_context": False,
        }

        update = {"current_plan": Plan.model_validate(plan_payload)}
        update.update(kept)
        plan_locale = plan_payload.get("locale")
        if plan_locale:
            update["locale"] = plan_locale

        assert update["locale"] == "en-US"

    def test_scenario_multiple_locale_updates_safe(self):
        """For each of several locales, state and plan agree and the update matches."""
        for locale in ["zh-CN", "en-US", "fr-FR"]:
            kept = preserve_state_meta_fields(State(messages=[], locale=locale))
            plan_payload = {
                "title": "Plan",
                "steps": [],
                "locale": locale,
                "has_enough_context": False,
            }

            update = {"current_plan": Plan.model_validate(plan_payload)}
            update.update(kept)
            plan_locale = plan_payload.get("locale")
            if plan_locale:
                update["locale"] = plan_locale

            assert update["locale"] == locale

View File

@@ -0,0 +1,355 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for state preservation functionality in graph nodes.
Tests the preserve_state_meta_fields() function and verifies that
critical state fields (especially locale) are properly preserved
across node state transitions.
"""
import pytest
from langgraph.types import Command
from src.graph.nodes import preserve_state_meta_fields
from src.graph.types import State
class TestPreserveStateMetaFields:
    """Unit tests for the dict returned by preserve_state_meta_fields()."""

    def test_preserve_all_fields_with_defaults(self):
        """A state holding only messages yields every field at its default."""
        kept = preserve_state_meta_fields(State(messages=[]))

        expected_defaults = {
            "locale": "en-US",
            "research_topic": "",
            "clarified_research_topic": "",
            "clarification_history": [],
            "enable_clarification": False,
            "max_clarification_rounds": 3,
            "clarification_rounds": 0,
            "resources": [],
        }

        # Presence first, then values.
        for name in expected_defaults:
            assert name in kept
        for name, default in expected_defaults.items():
            assert kept[name] == default
        # The flag must be the actual False object.
        assert kept["enable_clarification"] is False

    def test_preserve_locale_from_state(self):
        """A zh-CN locale set on the state is returned unchanged."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))
        assert kept["locale"] == "zh-CN"

    def test_preserve_locale_english(self):
        """An en-US locale set on the state is returned unchanged."""
        kept = preserve_state_meta_fields(State(messages=[], locale="en-US"))
        assert kept["locale"] == "en-US"

    def test_preserve_locale_with_custom_value(self):
        """A locale other than the two common ones (fr-FR) is returned unchanged."""
        kept = preserve_state_meta_fields(State(messages=[], locale="fr-FR"))
        assert kept["locale"] == "fr-FR"

    def test_preserve_research_topic(self):
        """research_topic is carried over."""
        topic = "How to build sustainable cities"
        kept = preserve_state_meta_fields(State(messages=[], research_topic=topic))
        assert kept["research_topic"] == topic

    def test_preserve_clarified_research_topic(self):
        """clarified_research_topic is carried over."""
        topic = "Sustainable urban development with focus on green spaces"
        kept = preserve_state_meta_fields(
            State(messages=[], clarified_research_topic=topic)
        )
        assert kept["clarified_research_topic"] == topic

    def test_preserve_clarification_history(self):
        """clarification_history is carried over."""
        exchanges = ["Q: What aspects?", "A: Architecture and planning"]
        kept = preserve_state_meta_fields(
            State(messages=[], clarification_history=exchanges)
        )
        assert kept["clarification_history"] == exchanges

    def test_preserve_clarification_flags(self):
        """The three clarification settings are carried over."""
        kept = preserve_state_meta_fields(
            State(
                messages=[],
                enable_clarification=True,
                max_clarification_rounds=5,
                clarification_rounds=2,
            )
        )

        assert kept["clarification_rounds"] == 2
        assert kept["max_clarification_rounds"] == 5
        assert kept["enable_clarification"] is True

    def test_preserve_resources(self):
        """The resources list is carried over."""
        attached = [{"id": "1", "name": "Resource 1"}]
        kept = preserve_state_meta_fields(State(messages=[], resources=attached))
        assert kept["resources"] == attached

    def test_preserve_all_fields_together(self):
        """All eight fields set at once come back with their values."""
        expected = {
            "locale": "zh-CN",
            "research_topic": "原始查询",
            "clarified_research_topic": "澄清后的查询",
            "clarification_history": ["Q1", "A1", "Q2", "A2"],
            "enable_clarification": True,
            "max_clarification_rounds": 5,
            "clarification_rounds": 2,
            "resources": ["resource1"],
        }

        kept = preserve_state_meta_fields(State(messages=[], **expected))

        for name, wanted in expected.items():
            assert kept[name] == wanted
        assert kept["enable_clarification"] is True

    def test_preserve_returns_dict_not_state_object(self):
        """The return value is a dict containing the expected keys."""
        kept = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

        assert isinstance(kept, dict)
        for name in ("locale", "research_topic"):
            assert name in kept

    def test_preserve_does_not_mutate_original_state(self):
        """Calling the helper leaves the input state unchanged."""
        original_locale = "zh-CN"
        source_state = State(messages=[], locale=original_locale)
        before = dict(source_state)

        preserve_state_meta_fields(source_state)

        assert dict(source_state) == before
        assert source_state["locale"] == original_locale

    def test_preserve_with_none_values(self):
        """A None locale does not raise; either None or the default is accepted."""
        kept = preserve_state_meta_fields(State(messages=[], locale=None))
        assert kept["locale"] in (None, "en-US")

    def test_preserve_empty_lists_preserved(self):
        """Explicit empty lists come back as empty lists."""
        kept = preserve_state_meta_fields(
            State(messages=[], clarification_history=[], resources=[])
        )

        assert kept["resources"] == []
        assert kept["clarification_history"] == []

    def test_preserve_count_of_fields(self):
        """Exactly eight fields are returned."""
        kept = preserve_state_meta_fields(State(messages=[]))
        assert len(kept) == 8

    def test_preserve_field_names(self):
        """The returned key set equals the eight expected names."""
        kept = preserve_state_meta_fields(State(messages=[]))

        expected_fields = {
            "locale",
            "research_topic",
            "clarified_research_topic",
            "clarification_history",
            "enable_clarification",
            "max_clarification_rounds",
            "clarification_rounds",
            "resources",
        }
        assert set(kept) == expected_fields
class TestStatePreservationInCommand:
    """Tests that the preserved dict composes into an update-style dict."""

    def test_command_update_with_preserved_fields(self):
        """Unpacking the preserved dict next to "messages" keeps its values."""
        source_state = State(messages=[], locale="zh-CN", research_topic="测试")

        command_update = {"messages": []}
        command_update.update(preserve_state_meta_fields(source_state))

        assert "research_topic" in command_update
        assert "locale" in command_update
        assert command_update["locale"] == "zh-CN"

    def test_command_unpacking_syntax(self):
        """A key assigned after the merge overrides the preserved value."""
        kept = preserve_state_meta_fields(State(messages=[], locale="en-US"))

        update_dict = {"messages": [], "current_plan": None, **kept}
        # Explicit value applied last, so it wins over the preserved one.
        update_dict["locale"] = "zh-CN"

        assert update_dict["locale"] == "zh-CN"
        # Two explicit keys plus the eight preserved ones.
        assert len(update_dict) >= 10
class TestLocalePreservationSpecific:
    """Locale-focused tests chaining several extract/rebuild steps."""

    def test_locale_not_lost_in_transition(self):
        """The locale is intact after two extract/rebuild hops."""
        start = State(messages=[], locale="zh-CN")

        # Hop one: extract, then rebuild a state from the extracted fields.
        first_kept = preserve_state_meta_fields(start)
        after_first = State(messages=[], **first_kept)

        # Hop two: extract again from the rebuilt state.
        second_kept = preserve_state_meta_fields(after_first)

        assert second_kept["locale"] == "zh-CN"

    def test_locale_chain_through_multiple_nodes(self):
        """The locale is intact after five extract/rebuild hops."""
        initial_locale = "zh-CN"
        current = State(messages=[], locale=initial_locale)

        hops = 5
        while hops:
            kept = preserve_state_meta_fields(current)
            assert kept["locale"] == initial_locale
            # Rebuild the state that the next hop starts from.
            current = State(messages=[], **kept)
            hops -= 1

        assert current.get("locale") == initial_locale

    def test_locale_with_other_fields_preserved_together(self):
        """Locale, topic and round counter all survive an extract/rebuild."""
        start = State(
            messages=[],
            locale="zh-CN",
            research_topic="Original",
            clarification_rounds=0,
        )

        kept = preserve_state_meta_fields(start)
        checks = {
            "locale": "zh-CN",
            "research_topic": "Original",
            "clarification_rounds": 0,
        }
        for name, wanted in checks.items():
            assert kept[name] == wanted

        rebuilt = State(messages=[], **kept)
        for name, wanted in checks.items():
            assert rebuilt.get(name) == wanted
class TestEdgeCases:
    """Boundary-value tests for preserve_state_meta_fields()."""

    def test_very_long_research_topic(self):
        """A 10,000-character topic is returned unchanged."""
        long_topic = 10000 * "a"
        kept = preserve_state_meta_fields(
            State(messages=[], research_topic=long_topic)
        )
        assert kept["research_topic"] == long_topic

    def test_unicode_characters_in_topic(self):
        """A topic mixing CJK text and emoji is returned unchanged."""
        unicode_topic = "中文测试 🌍 テスト 🧪"
        kept = preserve_state_meta_fields(
            State(messages=[], research_topic=unicode_topic)
        )
        assert kept["research_topic"] == unicode_topic

    def test_special_characters_in_locale(self):
        """Several hyphenated locale tags are each returned unchanged."""
        for locale in ("zh-CN", "en-US", "pt-BR", "es-ES", "ja-JP"):
            kept = preserve_state_meta_fields(State(messages=[], locale=locale))
            assert kept["locale"] == locale

    def test_large_clarification_history(self):
        """A 100-entry history is returned with the same length and content."""
        large_history = []
        for i in range(100):
            large_history.append(f"Q{i}: Question {i}")

        kept = preserve_state_meta_fields(
            State(messages=[], clarification_history=large_history)
        )

        assert kept["clarification_history"] == large_history
        assert len(kept["clarification_history"]) == 100

    def test_max_clarification_rounds_boundary(self):
        """A range of max_clarification_rounds values, including 0, is kept."""
        for value in (0, 1, 3, 10, 100, 999):
            kept = preserve_state_meta_fields(
                State(messages=[], max_clarification_rounds=value)
            )
            assert kept["max_clarification_rounds"] == value