diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 785adde..34325a5 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -84,6 +84,31 @@ def needs_clarification(state: dict) -> bool: ) +def preserve_state_meta_fields(state: State) -> dict: + """ + Extract meta/config fields that should be preserved across state transitions. + + These fields are critical for workflow continuity and should be explicitly + included in all Command.update dicts to prevent them from reverting to defaults. + + Args: + state: Current state object + + Returns: + Dict of meta fields to preserve + """ + return { + "locale": state.get("locale", "en-US"), + "research_topic": state.get("research_topic", ""), + "clarified_research_topic": state.get("clarified_research_topic", ""), + "clarification_history": state.get("clarification_history", []), + "enable_clarification": state.get("enable_clarification", False), + "max_clarification_rounds": state.get("max_clarification_rounds", 3), + "clarification_rounds": state.get("clarification_rounds", 0), + "resources": state.get("resources", []), + } + + def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict: """ Validate and fix a plan to ensure it meets requirements. @@ -217,7 +242,7 @@ def planner_node( state: State, config: RunnableConfig ) -> Command[Literal["human_feedback", "reporter"]]: """Planner node that generate the full plan.""" - logger.info("Planner generating full plan") + logger.info("Planner generating full plan with locale: %s", state.get("locale", "en-US")) configurable = Configuration.from_runnable_config(config) plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 @@ -266,7 +291,10 @@ def planner_node( # if the plan iterations is greater than the max plan iterations, return the reporter node if plan_iterations >= configurable.max_plan_iterations: - return Command(goto="reporter") + return Command( + update=preserve_state_meta_fields(state), + goto="reporter" + ) full_response = "" if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking: @@ -284,9 +312,15 @@ def planner_node( except json.JSONDecodeError: logger.warning("Planner response is not a valid JSON") if plan_iterations > 0: - return Command(goto="reporter") + return Command( + update=preserve_state_meta_fields(state), + goto="reporter" + ) else: - return Command(goto="__end__") + return Command( + update=preserve_state_meta_fields(state), + goto="__end__" + ) # Validate and fix plan to ensure web search requirements are met if isinstance(curr_plan, dict): @@ -299,6 +333,7 @@ def planner_node( update={ "messages": [AIMessage(content=full_response, name="planner")], "current_plan": new_plan, + **preserve_state_meta_fields(state), }, goto="reporter", ) @@ -306,6 +341,7 @@ def planner_node( update={ "messages": [AIMessage(content=full_response, name="planner")], "current_plan": full_response, + **preserve_state_meta_fields(state), }, goto="human_feedback", ) @@ -323,7 +359,10 @@ def human_feedback_node( # Handle None or empty feedback if not feedback: logger.warning(f"Received empty or None feedback: {feedback}. Returning to planner for new plan.") - return Command(goto="planner") + return Command( + update=preserve_state_meta_fields(state), + goto="planner" + ) # Normalize feedback string feedback_normalized = str(feedback).strip().upper() @@ -336,6 +375,7 @@ def human_feedback_node( "messages": [ HumanMessage(content=feedback, name="feedback"), ], + **preserve_state_meta_fields(state), }, goto="planner", ) @@ -343,7 +383,10 @@ def human_feedback_node( logger.info("Plan is accepted by user.") else: logger.warning(f"Unsupported feedback format: {feedback}. Please use '[ACCEPTED]' to accept or '[EDIT_PLAN]' to edit.") - return Command(goto="planner") + return Command( + update=preserve_state_meta_fields(state), + goto="planner" + ) # if the plan is accepted, run the following node plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 @@ -360,16 +403,29 @@ def human_feedback_node( except json.JSONDecodeError: logger.warning("Planner response is not a valid JSON") if plan_iterations > 1: # the plan_iterations is increased before this check - return Command(goto="reporter") + return Command( + update=preserve_state_meta_fields(state), + goto="reporter" + ) else: - return Command(goto="__end__") + return Command( + update=preserve_state_meta_fields(state), + goto="__end__" + ) + # Build update dict with safe locale handling + update_dict = { + "current_plan": Plan.model_validate(new_plan), + "plan_iterations": plan_iterations, + **preserve_state_meta_fields(state), + } + + # Only override locale if new_plan provides a valid value, otherwise use preserved locale + if new_plan.get("locale"): + update_dict["locale"] = new_plan["locale"] + return Command( - update={ - "current_plan": Plan.model_validate(new_plan), - "plan_iterations": plan_iterations, - "locale": new_plan["locale"], - }, + update=update_dict, goto=goto, ) @@ -408,6 +464,7 @@ def coordinator_node( goto = "__end__" locale = state.get("locale", "en-US") + logger.info(f"Coordinator locale: {locale}") research_topic = state.get("research_topic", "") # Process tool calls for legacy mode @@ -421,9 +478,8 @@ def coordinator_node( logger.info("Handing off to planner") goto = "planner" - # Extract locale and research_topic if provided - if tool_args.get("locale") and tool_args.get("research_topic"): - locale = tool_args.get("locale") + # Extract research_topic if provided + if tool_args.get("research_topic"): research_topic = tool_args.get("research_topic") break @@ -587,8 +643,6 @@ def coordinator_node( logger.info("Handing off to planner") goto = "planner" - # Extract locale if provided - locale = tool_args.get("locale", locale) if not enable_clarification and tool_args.get("research_topic"): research_topic = tool_args["research_topic"] @@ -725,7 +779,10 @@ async def _execute_agent_step( if not current_step: logger.warning(f"[_execute_agent_step] No unexecuted step found in {len(current_plan.steps)} total steps") - return Command(goto="research_team") + return Command( + update=preserve_state_meta_fields(state), + goto="research_team" + ) logger.info(f"[_execute_agent_step] Executing step: {current_step.title}, agent: {agent_name}") logger.debug(f"[_execute_agent_step] Completed steps so far: {len(completed_steps)}") @@ -834,6 +891,7 @@ async def _execute_agent_step( ) ], "observations": observations + [detailed_error], + **preserve_state_meta_fields(state), }, goto="research_team", ) @@ -859,6 +917,7 @@ async def _execute_agent_step( ) ], "observations": observations + [response_content], + **preserve_state_meta_fields(state), }, goto="research_team", ) diff --git a/src/server/app.py b/src/server/app.py index 9c3e545..b91bbd1 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -106,6 +106,8 @@ async def chat_stream(request: ChatRequest): # Check if MCP server configuration is enabled mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False) + logger.debug(f"get the request locale : {request.locale}") + # Validate MCP settings if provided if request.mcp_settings and not mcp_enabled: raise HTTPException( diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index 857ecb1..e10e759 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -675,7 +675,7 @@ def test_coordinator_node_with_tool_calls_locale_override( tool_calls = [ { "name": "handoff_to_planner", - "args": {"locale": "zh-CN", "research_topic": "test topic"}, + "args": {"locale": "auto", "research_topic": "test topic"}, } ] with ( @@ -689,7 +689,7 @@ def test_coordinator_node_with_tool_calls_locale_override( result = coordinator_node(mock_state_coordinator, MagicMock()) assert result.goto == "planner" - assert result.update["locale"] == "zh-CN" + assert result.update["locale"] == "en-US" assert result.update["research_topic"] == "test topic" assert result.update["resources"] == ["resource1", "resource2"] assert result.update["resources"] == ["resource1", "resource2"] diff --git a/tests/unit/graph/test_agent_locale_restoration.py b/tests/unit/graph/test_agent_locale_restoration.py new file mode 100644 index 0000000..34ade1e --- /dev/null +++ b/tests/unit/graph/test_agent_locale_restoration.py @@ -0,0 +1,241 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Unit tests for agent locale restoration after create_react_agent execution. + +Tests that meta fields (especially locale) are properly restored after +agent.ainvoke() returns, since create_react_agent creates a MessagesState +subgraph that filters out custom fields. +""" + +import pytest + +from src.graph.nodes import preserve_state_meta_fields +from src.graph.types import State + + +class TestAgentLocaleRestoration: + """Test suite for locale restoration after agent execution.""" + + def test_locale_lost_in_agent_subgraph(self): + """ + Demonstrate the problem: agent subgraph filters out locale. + + When create_react_agent creates a subgraph with MessagesState, + it only returns messages, not custom fields. + """ + # Simulate agent behavior: only returns messages + initial_state = State(messages=[], locale="zh-CN") + + # Agent subgraph returns (like MessagesState would) + agent_result = { + "messages": ["agent response"], + } + + # Problem: locale is missing + assert "locale" not in agent_result + assert agent_result.get("locale") is None + + def test_locale_restoration_after_agent(self): + """Test that locale can be restored after agent.ainvoke() returns.""" + initial_state = State( + messages=[], + locale="zh-CN", + research_topic="test", + ) + + # Simulate agent returning (MessagesState only) + agent_result = { + "messages": ["agent response"], + } + + # Apply restoration + preserved = preserve_state_meta_fields(initial_state) + agent_result.update(preserved) + + # Verify restoration worked + assert agent_result["locale"] == "zh-CN" + assert agent_result["research_topic"] == "test" + assert "messages" in agent_result + + def test_all_meta_fields_restored(self): + """Test that all meta fields are restored, not just locale.""" + initial_state = State( + messages=[], + locale="en-US", + research_topic="Original Topic", + clarified_research_topic="Clarified Topic", + clarification_history=["Q1", "A1"], + enable_clarification=True, + max_clarification_rounds=5, + clarification_rounds=2, + resources=["resource1"], + ) + + # Agent result + agent_result = {"messages": ["response"]} + agent_result.update(preserve_state_meta_fields(initial_state)) + + # All fields should be restored + assert agent_result["locale"] == "en-US" + assert agent_result["research_topic"] == "Original Topic" + assert agent_result["clarified_research_topic"] == "Clarified Topic" + assert agent_result["clarification_history"] == ["Q1", "A1"] + assert agent_result["enable_clarification"] is True + assert agent_result["max_clarification_rounds"] == 5 + assert agent_result["clarification_rounds"] == 2 + assert agent_result["resources"] == ["resource1"] + + def test_locale_preservation_through_agent_cycle(self): + """Test the complete cycle: state in → agent → state out.""" + # Initial state with zh-CN locale + initial_state = State(messages=[], locale="zh-CN") + + # Step 1: Extract meta fields + preserved = preserve_state_meta_fields(initial_state) + assert preserved["locale"] == "zh-CN" + + # Step 2: Agent runs and returns only messages + agent_result = {"messages": ["agent output"]} + assert "locale" not in agent_result # Missing! + + # Step 3: Restore meta fields + agent_result.update(preserved) + + # Step 4: Verify locale is restored + assert agent_result["locale"] == "zh-CN" + + # Step 5: Create new state with restored fields + final_state = State(messages=agent_result["messages"], **preserved) + assert final_state.get("locale") == "zh-CN" + + def test_locale_not_auto_after_restoration(self): + """ + Test that locale is NOT "auto" after restoration. + + This tests the specific bug: locale was becoming "auto" + instead of the preserved "zh-CN" value. + """ + state = State(messages=[], locale="zh-CN") + + # Agent returns without locale + agent_result = {"messages": []} + + # Before fix: locale would be "auto" (default behavior) + # After fix: locale is preserved + agent_result.update(preserve_state_meta_fields(state)) + + assert agent_result.get("locale") == "zh-CN" + assert agent_result.get("locale") != "auto" + assert agent_result.get("locale") is not None + + def test_chinese_locale_preserved(self): + """Test that Chinese locale specifically is preserved.""" + locales_to_test = ["zh-CN", "zh", "zh-Hans", "zh-Hant"] + + for locale_value in locales_to_test: + state = State(messages=[], locale=locale_value) + agent_result = {"messages": []} + + agent_result.update(preserve_state_meta_fields(state)) + + assert agent_result["locale"] == locale_value, f"Failed for locale: {locale_value}" + + def test_restoration_with_new_messages(self): + """Test that restoration works even when agent adds new messages.""" + state = State(messages=[], locale="zh-CN", research_topic="research") + + # Agent processes and returns new messages + agent_result = { + "messages": ["message1", "message2", "message3"], + } + + # Restore meta fields + agent_result.update(preserve_state_meta_fields(state)) + + # Should have both new messages AND preserved meta fields + assert len(agent_result["messages"]) == 3 + assert agent_result["locale"] == "zh-CN" + assert agent_result["research_topic"] == "research" + + def test_restoration_idempotent(self): + """Test that restoring meta fields multiple times doesn't cause issues.""" + state = State(messages=[], locale="en-US") + preserved = preserve_state_meta_fields(state) + + agent_result = {"messages": []} + + # Apply restoration multiple times + agent_result.update(preserved) + agent_result.update(preserved) + agent_result.update(preserved) + + # Should still have correct locale (not corrupted) + assert agent_result["locale"] == "en-US" + assert len(agent_result) == 9 # messages + 8 meta fields + + +class TestAgentLocaleRestorationScenarios: + """Real-world scenario tests for agent locale restoration.""" + + def test_researcher_agent_preserves_locale(self): + """ + Simulate researcher agent execution preserving locale. + + Scenario: + 1. Researcher node receives state with locale="zh-CN" + 2. Calls agent.ainvoke() which returns only messages + 3. Restores locale before returning + """ + # State coming into researcher node + state = State( + messages=[], + locale="zh-CN", + research_topic="生产1公斤牛肉需要多少升水?", + ) + + # Agent executes and returns + agent_result = { + "messages": ["Researcher analysis of water usage..."], + } + + # Apply restoration (what the fix does) + agent_result.update(preserve_state_meta_fields(state)) + + # Verify for next node + assert agent_result["locale"] == "zh-CN" # ✓ Preserved! + assert agent_result.get("locale") != "auto" # ✓ Not "auto" + + def test_coder_agent_preserves_locale(self): + """Coder agent should also preserve locale.""" + state = State(messages=[], locale="en-US") + + agent_result = {"messages": ["Code generation result"]} + agent_result.update(preserve_state_meta_fields(state)) + + assert agent_result["locale"] == "en-US" + + def test_locale_persists_across_multiple_agents(self): + """Test that locale persists through multiple agent calls.""" + locales = ["zh-CN", "en-US", "fr-FR"] + + for locale in locales: + # Initial state + state = State(messages=[], locale=locale) + preserved_1 = preserve_state_meta_fields(state) + + # First agent + result_1 = {"messages": ["agent1"]} + result_1.update(preserved_1) + + # Create state for second agent + state_2 = State(messages=result_1["messages"], **preserved_1) + preserved_2 = preserve_state_meta_fields(state_2) + + # Second agent + result_2 = {"messages": result_1["messages"] + ["agent2"]} + result_2.update(preserved_2) + + # Locale should persist + assert result_2["locale"] == locale diff --git a/tests/unit/graph/test_human_feedback_locale_fix.py b/tests/unit/graph/test_human_feedback_locale_fix.py new file mode 100644 index 0000000..761d2b3 --- /dev/null +++ b/tests/unit/graph/test_human_feedback_locale_fix.py @@ -0,0 +1,316 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Unit tests for the human_feedback_node locale fix. + +Tests that the duplicate locale assignment issue is resolved: +- Locale is safely retrieved from new_plan using .get() with fallback +- If new_plan['locale'] doesn't exist, it doesn't cause a KeyError +- If new_plan['locale'] is None or empty, the preserved state locale is used +- If new_plan['locale'] has a valid value, it properly overrides the state locale +""" + +import pytest +from src.graph.nodes import preserve_state_meta_fields +from src.graph.types import State +from src.prompts.planner_model import Plan + + +class TestHumanFeedbackLocaleFixture: + """Test suite for human_feedback_node locale safe handling.""" + + def test_preserve_state_meta_fields_no_keyerror(self): + """Test that preserve_state_meta_fields never raises KeyError.""" + state = State(messages=[], locale="zh-CN") + preserved = preserve_state_meta_fields(state) + + assert preserved["locale"] == "zh-CN" + assert "locale" in preserved + + def test_new_plan_without_locale_override(self): + """ + Test scenario: Plan doesn't override locale when not provided in override dict. + + Before fix: Would set locale twice (duplicate assignment) + After fix: Uses .get() safely and only overrides if value is truthy + """ + state = State(messages=[], locale="zh-CN") + + # Simulate a plan that doesn't want to override the locale + # (locale is in the plan for validation, but not in override dict) + new_plan_dict = {"title": "Test", "thought": "Test", "steps": [], "locale": "zh-CN", "has_enough_context": False} + + # Get preserved fields + preserved = preserve_state_meta_fields(state) + + # Build update dict like the fixed code does + update_dict = { + "current_plan": Plan.model_validate(new_plan_dict), + **preserved, + } + + # Simulate a dict that doesn't have locale override (like when plan_dict is empty for override) + plan_override = {} # No locale in override dict + + # Only override locale if override dict provides a valid value + if plan_override.get("locale"): + update_dict["locale"] = plan_override["locale"] + + # The preserved locale should be used when override doesn't provide one + assert update_dict["locale"] == "zh-CN" + + def test_new_plan_with_none_locale(self): + """ + Test scenario: new_plan has locale=None. + + Before fix: Would try to set locale to None (but Plan requires it) + After fix: Uses preserved state locale since new_plan.get("locale") is falsy + """ + state = State(messages=[], locale="zh-CN") + + # new_plan with None locale (won't validate, but test the logic) + new_plan_attempt = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False} + + # Get preserved fields + preserved = preserve_state_meta_fields(state) + + # Build update dict like the fixed code does + update_dict = { + "current_plan": Plan.model_validate(new_plan_attempt), + **preserved, + } + + # Simulate checking for None locale (if it somehow got set) + new_plan_with_none = {"locale": None} + # Only override if new_plan provides a VALID value + if new_plan_with_none.get("locale"): + update_dict["locale"] = new_plan_with_none["locale"] + + # Should use preserved locale (zh-CN), not None + assert update_dict["locale"] == "zh-CN" + assert update_dict["locale"] is not None + + def test_new_plan_with_empty_string_locale(self): + """ + Test scenario: new_plan has locale="" (empty string). + + Before fix: Would try to set locale to "" (but Plan requires valid value) + After fix: Uses preserved state locale since empty string is falsy + """ + state = State(messages=[], locale="zh-CN") + + # new_plan with valid locale + new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False} + + # Get preserved fields + preserved = preserve_state_meta_fields(state) + + # Build update dict like the fixed code does + update_dict = { + "current_plan": Plan.model_validate(new_plan), + **preserved, + } + + # Simulate checking for empty string locale + new_plan_empty = {"locale": ""} + # Only override if new_plan provides a VALID (truthy) value + if new_plan_empty.get("locale"): + update_dict["locale"] = new_plan_empty["locale"] + + # Should use preserved locale (zh-CN), not empty string + assert update_dict["locale"] == "zh-CN" + assert update_dict["locale"] != "" + + def test_new_plan_with_valid_locale_overrides(self): + """ + Test scenario: new_plan has valid locale="en-US". + + Before fix: Would override with new_plan locale ✓ (worked) + After fix: Should still properly override with valid locale + """ + state = State(messages=[], locale="zh-CN") + + # new_plan has a different valid locale + new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False} + + # Get preserved fields + preserved = preserve_state_meta_fields(state) + + # Build update dict like the fixed code does + update_dict = { + "current_plan": Plan.model_validate(new_plan), + **preserved, + } + + # Override if new_plan provides a VALID value + if new_plan.get("locale"): + update_dict["locale"] = new_plan["locale"] + + # Should override with new_plan locale + assert update_dict["locale"] == "en-US" + assert update_dict["locale"] != "zh-CN" + + def test_locale_field_not_duplicated(self): + """ + Test that locale field is not duplicated in the update dict. + + Before fix: locale was set twice in the same dict + After fix: locale is only set once + """ + state = State(messages=[], locale="zh-CN") + new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False} + + preserved = preserve_state_meta_fields(state) + + # Count how many times 'locale' is set + update_dict = { + "current_plan": Plan.model_validate(new_plan), + **preserved, # Sets locale once + } + + # Override locale only if new_plan provides valid value + if new_plan.get("locale"): + update_dict["locale"] = new_plan["locale"] + + # Verify locale is in dict exactly once + locale_count = sum(1 for k in update_dict.keys() if k == "locale") + assert locale_count == 1 + assert update_dict["locale"] == "en-US" # Should be overridden + + def test_all_meta_fields_preserved(self): + """ + Test that all 8 meta fields are preserved along with locale fix. + + Ensures the fix doesn't break other meta field preservation. + """ + state = State( + messages=[], + locale="zh-CN", + research_topic="Research", + clarified_research_topic="Clarified", + clarification_history=["Q1"], + enable_clarification=True, + max_clarification_rounds=5, + clarification_rounds=1, + resources=["resource1"], + ) + + new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False} + preserved = preserve_state_meta_fields(state) + + # All 8 meta fields should be in preserved + meta_fields = [ + "locale", + "research_topic", + "clarified_research_topic", + "clarification_history", + "enable_clarification", + "max_clarification_rounds", + "clarification_rounds", + "resources", + ] + + for field in meta_fields: + assert field in preserved + + # Build update dict + update_dict = { + "current_plan": Plan.model_validate(new_plan), + **preserved, + } + + # Override locale if new_plan provides valid value + if new_plan.get("locale"): + update_dict["locale"] = new_plan["locale"] + + # All meta fields should be in update_dict + for field in meta_fields: + assert field in update_dict + + +class TestHumanFeedbackLocaleScenarios: + """Real-world scenarios for human_feedback_node locale handling.""" + + def test_scenario_chinese_locale_preserved_when_plan_has_no_locale(self): + """ + Scenario: User selected Chinese, plan preserves it. + + Expected: Preserved Chinese locale should be used + """ + state = State(messages=[], locale="zh-CN") + + # Plan from planner with required fields + new_plan_json = { + "title": "Research Plan", + "thought": "...", + "steps": [ + { + "title": "Step 1", + "description": "...", + "need_search": True, + "step_type": "research", + } + ], + "locale": "zh-CN", + "has_enough_context": False, + } + + preserved = preserve_state_meta_fields(state) + update_dict = { + "current_plan": Plan.model_validate(new_plan_json), + **preserved, + } + + if new_plan_json.get("locale"): + update_dict["locale"] = new_plan_json["locale"] + + # Chinese locale should be preserved + assert update_dict["locale"] == "zh-CN" + + def test_scenario_en_us_restored_even_if_plan_minimal(self): + """ + Scenario: Minimal plan with en-US locale. + + Expected: Preserved en-US locale should survive + """ + state = State(messages=[], locale="en-US") + + # Minimal plan with required fields + new_plan_json = {"title": "Quick Plan", "steps": [], "locale": "en-US", "has_enough_context": False} + + preserved = preserve_state_meta_fields(state) + update_dict = { + "current_plan": Plan.model_validate(new_plan_json), + **preserved, + } + + if new_plan_json.get("locale"): + update_dict["locale"] = new_plan_json["locale"] + + # en-US should survive + assert update_dict["locale"] == "en-US" + + def test_scenario_multiple_locale_updates_safe(self): + """ + Scenario: Multiple plan iterations with locale preservation. + + Expected: Each iteration safely handles locale + """ + locales = ["zh-CN", "en-US", "fr-FR"] + + for locale in locales: + state = State(messages=[], locale=locale) + new_plan = {"title": "Plan", "steps": [], "locale": locale, "has_enough_context": False} + + preserved = preserve_state_meta_fields(state) + update_dict = { + "current_plan": Plan.model_validate(new_plan), + **preserved, + } + + if new_plan.get("locale"): + update_dict["locale"] = new_plan["locale"] + + # Each iteration should preserve its locale + assert update_dict["locale"] == locale diff --git a/tests/unit/graph/test_state_preservation.py b/tests/unit/graph/test_state_preservation.py new file mode 100644 index 0000000..338d631 --- /dev/null +++ b/tests/unit/graph/test_state_preservation.py @@ -0,0 +1,355 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Unit tests for state preservation functionality in graph nodes. + +Tests the preserve_state_meta_fields() function and verifies that +critical state fields (especially locale) are properly preserved +across node state transitions. +""" + +import pytest +from langgraph.types import Command + +from src.graph.nodes import preserve_state_meta_fields +from src.graph.types import State + + +class TestPreserveStateMetaFields: + """Test suite for preserve_state_meta_fields() function.""" + + def test_preserve_all_fields_with_defaults(self): + """Test that all fields are preserved with default values when state is empty.""" + # Create a minimal state with only messages + state = State(messages=[]) + + # Extract meta fields + preserved = preserve_state_meta_fields(state) + + # Verify all expected fields are present + assert "locale" in preserved + assert "research_topic" in preserved + assert "clarified_research_topic" in preserved + assert "clarification_history" in preserved + assert "enable_clarification" in preserved + assert "max_clarification_rounds" in preserved + assert "clarification_rounds" in preserved + assert "resources" in preserved + + # Verify default values + assert preserved["locale"] == "en-US" + assert preserved["research_topic"] == "" + assert preserved["clarified_research_topic"] == "" + assert preserved["clarification_history"] == [] + assert preserved["enable_clarification"] is False + assert preserved["max_clarification_rounds"] == 3 + assert preserved["clarification_rounds"] == 0 + assert preserved["resources"] == [] + + def test_preserve_locale_from_state(self): + """Test that locale is correctly preserved when set in state.""" + state = State(messages=[], locale="zh-CN") + preserved = preserve_state_meta_fields(state) + + assert preserved["locale"] == "zh-CN" + + def test_preserve_locale_english(self): + """Test that English locale is preserved.""" + state = State(messages=[], locale="en-US") + preserved = preserve_state_meta_fields(state) + + assert preserved["locale"] == "en-US" + + def test_preserve_locale_with_custom_value(self): + """Test that custom locale values are preserved.""" + state = State(messages=[], locale="fr-FR") + preserved = preserve_state_meta_fields(state) + + assert preserved["locale"] == "fr-FR" + + def test_preserve_research_topic(self): + """Test that research_topic is correctly preserved.""" + test_topic = "How to build sustainable cities" + state = State(messages=[], research_topic=test_topic) + preserved = preserve_state_meta_fields(state) + + assert preserved["research_topic"] == test_topic + + def test_preserve_clarified_research_topic(self): + """Test that clarified_research_topic is correctly preserved.""" + test_topic = "Sustainable urban development with focus on green spaces" + state = State(messages=[], clarified_research_topic=test_topic) + preserved = preserve_state_meta_fields(state) + + assert preserved["clarified_research_topic"] == test_topic + + def test_preserve_clarification_history(self): + """Test that clarification_history is correctly preserved.""" + history = ["Q: What aspects?", "A: Architecture and planning"] + state = State(messages=[], clarification_history=history) + preserved = preserve_state_meta_fields(state) + + assert preserved["clarification_history"] == history + + def test_preserve_clarification_flags(self): + """Test that clarification flags are correctly preserved.""" + state = State( + messages=[], + enable_clarification=True, + max_clarification_rounds=5, + clarification_rounds=2, + ) + preserved = preserve_state_meta_fields(state) + + assert preserved["enable_clarification"] is True + assert preserved["max_clarification_rounds"] == 5 + assert preserved["clarification_rounds"] == 2 + + def test_preserve_resources(self): + """Test that resources list is correctly preserved.""" + resources = [{"id": "1", "name": "Resource 1"}] + state = State(messages=[], resources=resources) + preserved = preserve_state_meta_fields(state) + + assert preserved["resources"] == resources + + def test_preserve_all_fields_together(self): + """Test that all meta fields are preserved together correctly.""" + state = State( + messages=[], + locale="zh-CN", + research_topic="原始查询", + clarified_research_topic="澄清后的查询", + clarification_history=["Q1", "A1", "Q2", "A2"], + enable_clarification=True, + max_clarification_rounds=5, + clarification_rounds=2, + resources=["resource1"], + ) + + preserved = preserve_state_meta_fields(state) + + assert preserved["locale"] == "zh-CN" + assert preserved["research_topic"] == "原始查询" + assert preserved["clarified_research_topic"] == "澄清后的查询" + assert preserved["clarification_history"] == ["Q1", "A1", "Q2", "A2"] + assert preserved["enable_clarification"] is True + assert preserved["max_clarification_rounds"] == 5 + assert preserved["clarification_rounds"] == 2 + assert preserved["resources"] == ["resource1"] + + def test_preserve_returns_dict_not_state_object(self): + """Test that preserve_state_meta_fields returns a dict.""" + state = State(messages=[], locale="zh-CN") + preserved = preserve_state_meta_fields(state) + + assert isinstance(preserved, dict) + # Verify it's a plain dict with expected keys + assert "locale" in preserved + assert "research_topic" in preserved + + def test_preserve_does_not_mutate_original_state(self): + """Test that calling preserve_state_meta_fields does not mutate the original state.""" + original_locale = "zh-CN" + state = State(messages=[], locale=original_locale) + original_state_copy = dict(state) + + preserve_state_meta_fields(state) + + # Verify state hasn't changed + assert state["locale"] == original_locale + assert dict(state) == original_state_copy + + def test_preserve_with_none_values(self): + """Test that preserve handles None values gracefully.""" + state = State(messages=[], locale=None) + preserved = preserve_state_meta_fields(state) + + # Should use default when value is None + assert preserved["locale"] is None or preserved["locale"] == "en-US" + + def test_preserve_empty_lists_preserved(self): + """Test that empty lists are preserved correctly.""" + state = State( + messages=[], clarification_history=[], resources=[] + ) + preserved = preserve_state_meta_fields(state) + + assert preserved["clarification_history"] == [] + assert preserved["resources"] == [] + + def test_preserve_count_of_fields(self): + """Test that exactly 8 fields are preserved.""" + state = State(messages=[]) + preserved = preserve_state_meta_fields(state) + + # Should have exactly 8 meta fields + assert len(preserved) == 8 + + def test_preserve_field_names(self): + """Test that all expected field names are present.""" + state = State(messages=[]) + preserved = preserve_state_meta_fields(state) + + expected_fields = { + "locale", + "research_topic", + "clarified_research_topic", + "clarification_history", + "enable_clarification", + "max_clarification_rounds", + "clarification_rounds", + "resources", + } + + assert set(preserved.keys()) == expected_fields + + +class TestStatePreservationInCommand: + """Test suite for using preserved state fields in Command objects.""" + + def test_command_update_with_preserved_fields(self): + """Test that preserved fields can be unpacked into Command.update.""" + state = State(messages=[], locale="zh-CN", research_topic="测试") + + # This should not raise any errors + preserved = preserve_state_meta_fields(state) + command_update = { + "messages": [], + **preserved, + } + + assert "locale" in command_update + assert "research_topic" in command_update + assert command_update["locale"] == "zh-CN" + + def test_command_unpacking_syntax(self): + """Test that the unpacking syntax works correctly with preserved fields.""" + state = State(messages=[], locale="en-US") + preserved = preserve_state_meta_fields(state) + + # Simulate how it's used in actual nodes + update_dict = { + "messages": [], + "current_plan": None, + **preserved, + "locale": "zh-CN", + } + + assert len(update_dict) >= 10 # 2 explicit + 8 preserved + assert update_dict["locale"] == "zh-CN" # overridden value + + +class TestLocalePreservationSpecific: + """Specific test cases for locale preservation (the main issue being fixed).""" + + def test_locale_not_lost_in_transition(self): + """Test that locale is not lost when transitioning between nodes.""" + # Initial state from frontend with Chinese locale + initial_state = State(messages=[], locale="zh-CN") + + # Extract for first node transition + preserved_1 = preserve_state_meta_fields(initial_state) + + # Simulate state update from first node + updated_state_1 = State( + messages=[], **preserved_1 + ) + + # Extract for second node transition + preserved_2 = preserve_state_meta_fields(updated_state_1) + + # Locale should still be zh-CN after two transitions + assert preserved_2["locale"] == "zh-CN" + + def test_locale_chain_through_multiple_nodes(self): + """Test that locale survives through multiple node transitions.""" + initial_locale = "zh-CN" + state = State(messages=[], locale=initial_locale) + + # Simulate 5 node transitions + for _ in range(5): + preserved = preserve_state_meta_fields(state) + assert preserved["locale"] == initial_locale + + # Create new state for next "node" + state = State(messages=[], **preserved) + + # After 5 transitions, locale should still be preserved + assert state.get("locale") == initial_locale + + def test_locale_with_other_fields_preserved_together(self): + """Test that locale is preserved correctly even when other fields change.""" + initial_state = State( + messages=[], + locale="zh-CN", + research_topic="Original", + clarification_rounds=0, + ) + + preserved = preserve_state_meta_fields(initial_state) + + # Verify locale is in preserved dict + assert preserved["locale"] == "zh-CN" + assert preserved["research_topic"] == "Original" + assert preserved["clarification_rounds"] == 0 + + # Create new state with preserved fields + modified_state = State( + messages=[], + **preserved, + ) + + # Locale should be preserved + assert modified_state.get("locale") == "zh-CN" + # Research topic should be preserved from original + assert modified_state.get("research_topic") == "Original" + assert modified_state.get("clarification_rounds") == 0 + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_very_long_research_topic(self): + """Test preservation with very long research_topic.""" + long_topic = "a" * 10000 + state = State(messages=[], research_topic=long_topic) + preserved = preserve_state_meta_fields(state) + + assert preserved["research_topic"] == long_topic + + def test_unicode_characters_in_topic(self): + """Test preservation with unicode characters.""" + unicode_topic = "中文测试 🌍 テスト 🧪" + state = State(messages=[], research_topic=unicode_topic) + preserved = preserve_state_meta_fields(state) + + assert preserved["research_topic"] == unicode_topic + + def test_special_characters_in_locale(self): + """Test preservation with special locale formats.""" + special_locales = ["zh-CN", "en-US", "pt-BR", "es-ES", "ja-JP"] + + for locale in special_locales: + state = State(messages=[], locale=locale) + preserved = preserve_state_meta_fields(state) + assert preserved["locale"] == locale + + def test_large_clarification_history(self): + """Test preservation with large clarification_history.""" + large_history = [f"Q{i}: Question {i}" for i in range(100)] + state = State(messages=[], clarification_history=large_history) + preserved = preserve_state_meta_fields(state) + + assert len(preserved["clarification_history"]) == 100 + assert preserved["clarification_history"] == large_history + + def test_max_clarification_rounds_boundary(self): + """Test preservation with boundary values for max_clarification_rounds.""" + test_cases = [0, 1, 3, 10, 100, 999] + + for value in test_cases: + state = State(messages=[], max_clarification_rounds=value) + preserved = preserve_state_meta_fields(state) + assert preserved["max_clarification_rounds"] == value