From ee02b9f637aa859943b9ef45bb25e0b0f1bf0a0b Mon Sep 17 00:00:00 2001 From: Xun Date: Mon, 26 Jan 2026 21:10:18 +0800 Subject: [PATCH] feat: Generate a fallback report upon recursion limit hit (#838) * finish handle_recursion_limit_fallback * fix * renmae test file * fix * doc --------- Co-authored-by: lxl0413 --- docs/configuration_guide.md | 25 + src/config/configuration.py | 3 + src/graph/nodes.py | 130 +++- src/prompts/recursion_fallback.md | 16 + src/prompts/template.py | 31 +- tests/integration/test_nodes.py | 79 +++ .../unit/graph/test_nodes_recursion_limit.py | 623 ++++++++++++++++++ 7 files changed, 895 insertions(+), 12 deletions(-) create mode 100644 src/prompts/recursion_fallback.md create mode 100644 tests/unit/graph/test_nodes_recursion_limit.py diff --git a/docs/configuration_guide.md b/docs/configuration_guide.md index 8979cc7..3ec59ef 100644 --- a/docs/configuration_guide.md +++ b/docs/configuration_guide.md @@ -305,6 +305,31 @@ Or via API request parameter: --- +## Recursion Fallback Configuration + +When agents hit the recursion limit, DeerFlow can gracefully generate a summary of accumulated findings instead of failing (enabled by default). + +### Configuration + +In `conf.yaml`: +```yaml +ENABLE_RECURSION_FALLBACK: true +``` + +### Recursion Limit + +Set the maximum recursion limit via environment variable: +```bash +export AGENT_RECURSION_LIMIT=50 # default: 25 +``` + +Or in `.env`: +```ini +AGENT_RECURSION_LIMIT=50 +``` + +--- + ## RAG (Retrieval-Augmented Generation) Configuration DeerFlow supports multiple RAG providers for document retrieval. Configure the RAG provider by setting environment variables. diff --git a/src/config/configuration.py b/src/config/configuration.py index 8414f1c..e2235f4 100644 --- a/src/config/configuration.py +++ b/src/config/configuration.py @@ -63,6 +63,9 @@ class Configuration: interrupt_before_tools: list[str] = field( default_factory=list ) # List of tool names to interrupt before execution + enable_recursion_fallback: bool = ( + True # Enable graceful fallback when recursion limit is reached + ) @classmethod def from_runnable_config( diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 69ef232..f8c4591 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -7,10 +7,11 @@ import os from functools import partial from typing import Annotated, Any, Literal -from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import tool from langchain_mcp_adapters.client import MultiServerMCPClient +from langgraph.errors import GraphRecursionError from langgraph.types import Command, interrupt from src.agents import create_agent @@ -19,7 +20,7 @@ from src.config.agents import AGENT_LLM_MAP from src.config.configuration import Configuration from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type from src.prompts.planner_model import Plan -from src.prompts.template import apply_prompt_template +from src.prompts.template import apply_prompt_template, get_system_prompt_template from src.tools import ( crawl_tool, get_retriever_tool, @@ -929,6 +930,79 @@ def validate_web_search_usage(messages: list, agent_name: str = "agent") -> bool return web_search_used +async def _handle_recursion_limit_fallback( + messages: list, + agent_name: str, + current_step, + state: State, +) -> list: + """Handle GraphRecursionError with graceful fallback using LLM summary. + + When the agent hits the recursion limit, this function generates a final output + using only the observations already gathered, without calling any tools. + + Args: + messages: Messages accumulated during agent execution before hitting limit + agent_name: Name of the agent that hit the limit + current_step: The current step being executed + state: Current workflow state + + Returns: + list: Messages including the accumulated messages plus the fallback summary + + Raises: + Exception: If the fallback LLM call fails + """ + logger.warning( + f"Recursion limit reached for {agent_name} agent. " + f"Attempting graceful fallback with {len(messages)} accumulated messages." + ) + + if len(messages) == 0: + return messages + + cleared_messages = messages.copy() + while len(cleared_messages) > 0 and cleared_messages[-1].type == "system": + cleared_messages = cleared_messages[:-1] + + # Prepare state for prompt template + fallback_state = { + "locale": state.get("locale", "en-US"), + } + + # Apply the recursion_fallback prompt template + system_prompt = get_system_prompt_template(agent_name, fallback_state, None, fallback_state.get("locale", "en-US")) + limit_prompt = get_system_prompt_template("recursion_fallback", fallback_state, None, fallback_state.get("locale", "en-US")) + fallback_messages = cleared_messages + [ + SystemMessage(content=system_prompt), + SystemMessage(content=limit_prompt) + ] + + # Get the LLM without tools (strip all tools from binding) + fallback_llm = get_llm_by_type(AGENT_LLM_MAP[agent_name]) + + # Call the LLM with the updated messages + fallback_response = fallback_llm.invoke(fallback_messages) + fallback_content = fallback_response.content + + logger.info( + f"Graceful fallback succeeded for {agent_name} agent. " + f"Generated summary of {len(fallback_content)} characters." + ) + + # Sanitize response + fallback_content = sanitize_tool_response(str(fallback_content)) + + # Update the step with the fallback result + current_step.execution_res = fallback_content + + # Return the accumulated messages plus the fallback response + result_messages = list(cleared_messages) + result_messages.append(AIMessage(content=fallback_content, name=agent_name)) + + return result_messages + + async def _execute_agent_step( state: State, agent, agent_name: str, config: RunnableConfig = None ) -> Command[Literal["research_team"]]: @@ -1049,11 +1123,51 @@ async def _execute_agent_step( f"Context compression for {agent_name}: {len(compressed_state.get('messages', []))} messages, " f"estimated tokens before: ~{token_count_before}, after: ~{token_count_after}" ) - + try: - result = await agent.ainvoke( - input=agent_input, config={"recursion_limit": recursion_limit} - ) + # Use stream from the start to capture messages in real-time + # This allows us to retrieve accumulated messages even if recursion limit is hit + accumulated_messages = [] + for chunk in agent.stream( + input=agent_input, + config={"recursion_limit": recursion_limit}, + stream_mode="values", + ): + if isinstance(chunk, dict) and "messages" in chunk: + accumulated_messages = chunk["messages"] + + # If we get here, execution completed successfully + result = {"messages": accumulated_messages} + except GraphRecursionError: + # Check if recursion fallback is enabled + configurable = Configuration.from_runnable_config(config) if config else Configuration() + + if configurable.enable_recursion_fallback: + try: + # Call fallback with accumulated messages (function returns list of messages) + response_messages = await _handle_recursion_limit_fallback( + messages=accumulated_messages, + agent_name=agent_name, + current_step=current_step, + state=state, + ) + + # Create result dict so the code can continue normally from line 1178 + result = {"messages": response_messages} + except Exception as fallback_error: + # If fallback fails, log and fall through to standard error handling + logger.error( + f"Recursion fallback failed for {agent_name} agent: {fallback_error}. " + "Falling back to standard error handling." + ) + raise + else: + # Fallback disabled, let error propagate to standard handler + logger.info( + f"Recursion limit reached but graceful fallback is disabled. " + "Using standard error handling." + ) + raise except Exception as e: import traceback @@ -1088,8 +1202,10 @@ async def _execute_agent_step( goto="research_team", ) + response_messages = result["messages"] + # Process the result - response_content = result["messages"][-1].content + response_content = response_messages[-1].content # Sanitize response to remove extra tokens and truncate if needed response_content = sanitize_tool_response(str(response_content)) diff --git a/src/prompts/recursion_fallback.md b/src/prompts/recursion_fallback.md new file mode 100644 index 0000000..43417bd --- /dev/null +++ b/src/prompts/recursion_fallback.md @@ -0,0 +1,16 @@ +--- +CURRENT_TIME: {{ CURRENT_TIME }} +locale: {{ locale }} +--- + +You have reached the maximum number of reasoning steps. + +Using ONLY the tool observations already produced, +write the final research report in EXACTLY the same format +as you would normally output at the end of this task. + +Do not call any tools. +Do not add new information. +If something is missing, state it explicitly. + +Always output in the locale of **{{ locale }}**. diff --git a/src/prompts/template.py b/src/prompts/template.py index cba167f..e203189 100644 --- a/src/prompts/template.py +++ b/src/prompts/template.py @@ -4,7 +4,6 @@ import dataclasses import os from datetime import datetime - from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape from langchain.agents import AgentState @@ -61,6 +60,28 @@ def apply_prompt_template( Returns: List of messages with the system prompt as the first message """ + try: + system_prompt = get_system_prompt_template(prompt_name, state, configurable, locale) + return [{"role": "system", "content": system_prompt}] + state["messages"] + except Exception as e: + raise ValueError(f"Error applying template {prompt_name} for locale {locale}: {e}") + +def get_system_prompt_template( + prompt_name: str, state: AgentState, configurable: Configuration = None, locale: str = "en-US" +) -> str: + """ + Render and return the system prompt template with state and configuration variables. + This function loads a Jinja2-based prompt template (with optional locale-specific + variants), applies variables from the agent state and Configuration object, and + returns the fully rendered system prompt string. + Args: + prompt_name: Name of the prompt template to load (without .md extension). + state: Current agent state containing variables available to the template. + configurable: Optional Configuration object providing additional template variables. + locale: Language locale for template selection (e.g., en-US, zh-CN). + Returns: + The rendered system prompt string after applying all template variables. + """ # Convert state to dict for template rendering state_vars = { "CURRENT_TIME": datetime.now().strftime("%a %b %d %Y %H:%M:%S %z"), @@ -74,15 +95,15 @@ def apply_prompt_template( try: # Normalize locale format normalized_locale = locale.replace("-", "_") if locale and locale.strip() else "en_US" - + # Try locale-specific template first try: template = env.get_template(f"{prompt_name}.{normalized_locale}.md") except TemplateNotFound: # Fallback to English template template = env.get_template(f"{prompt_name}.md") - + system_prompt = template.render(**state_vars) - return [{"role": "system", "content": system_prompt}] + state["messages"] + return system_prompt except Exception as e: - raise ValueError(f"Error applying template {prompt_name} for locale {locale}: {e}") + raise ValueError(f"Error loading template {prompt_name} for locale {locale}: {e}") \ No newline at end of file diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index 40ff175..46c6d90 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -1107,7 +1107,12 @@ def mock_agent(): # Simulate agent returning a message list return {"messages": [MagicMock(content="result content")]} + def stream(input, config, stream_mode): + # Simulate agent.stream() yielding messages + yield {"messages": [MagicMock(content="result content")]} + agent.ainvoke = ainvoke + agent.stream = stream return agent @@ -1172,7 +1177,12 @@ async def test_execute_agent_step_with_resources_and_researcher(mock_step): assert any("DO NOT include inline citations" in m.content for m in messages) return {"messages": [MagicMock(content="resource result")]} + def stream(input, config, stream_mode): + # Simulate agent.stream() yielding messages + yield {"messages": [MagicMock(content="resource result")]} + agent.ainvoke = ainvoke + agent.stream = stream with patch( "src.graph.nodes.HumanMessage", side_effect=lambda content, name=None: MagicMock(content=content, name=name), @@ -2414,7 +2424,43 @@ async def test_execute_agent_step_preserves_multiple_tool_messages(): ] return {"messages": messages} + def stream(input, config, stream_mode): + # Simulate agent.stream() yielding the final messages + messages = [ + AIMessage( + content="I'll search for information about this topic.", + tool_calls=[{ + "id": "call_1", + "name": "web_search", + "args": {"query": "first search query"} + }] + ), + ToolMessage( + content="First search result content here", + tool_call_id="call_1", + name="web_search", + ), + AIMessage( + content="Let me search for more specific information.", + tool_calls=[{ + "id": "call_2", + "name": "web_search", + "args": {"query": "second search query"} + }] + ), + ToolMessage( + content="Second search result content here", + tool_call_id="call_2", + name="web_search", + ), + AIMessage( + content="Based on my research, here is the comprehensive answer..." + ), + ] + yield {"messages": messages} + agent.ainvoke = mock_ainvoke + agent.stream = stream # Execute the agent step with patch( @@ -2510,7 +2556,30 @@ async def test_execute_agent_step_single_tool_call_still_works(): ] return {"messages": messages} + def stream(input, config, stream_mode): + # Simulate agent.stream() yielding the messages + messages = [ + AIMessage( + content="I'll search for information.", + tool_calls=[{ + "id": "call_1", + "name": "web_search", + "args": {"query": "search query"} + }] + ), + ToolMessage( + content="Search result content", + tool_call_id="call_1", + name="web_search", + ), + AIMessage( + content="Here is the answer based on the search result." + ), + ] + yield {"messages": messages} + agent.ainvoke = mock_ainvoke + agent.stream = stream with patch( "src.graph.nodes.HumanMessage", @@ -2570,7 +2639,17 @@ async def test_execute_agent_step_no_tool_calls_still_works(): ] return {"messages": messages} + def stream(input, config, stream_mode): + # Simulate agent.stream() yielding messages without tool calls + messages = [ + AIMessage( + content="Based on my knowledge, here is the answer without needing to search." + ), + ] + yield {"messages": messages} + agent.ainvoke = mock_ainvoke + agent.stream = stream with patch( "src.graph.nodes.HumanMessage", diff --git a/tests/unit/graph/test_nodes_recursion_limit.py b/tests/unit/graph/test_nodes_recursion_limit.py new file mode 100644 index 0000000..86987a4 --- /dev/null +++ b/tests/unit/graph/test_nodes_recursion_limit.py @@ -0,0 +1,623 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Unit tests for recursion limit fallback functionality in graph nodes. + +Tests the graceful fallback behavior when agents hit the recursion limit, +including the _handle_recursion_limit_fallback function and the +enable_recursion_fallback configuration option. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from src.config.configuration import Configuration +from src.graph.nodes import _handle_recursion_limit_fallback +from src.graph.types import State + + +class TestHandleRecursionLimitFallback: + """Test suite for _handle_recursion_limit_fallback() function.""" + + @pytest.mark.asyncio + async def test_fallback_generates_summary_from_observations(self): + """Test that fallback generates summary using accumulated agent messages.""" + from langchain_core.messages import ToolCall + + # Create test state with messages + state = State( + messages=[ + HumanMessage(content="Research topic: AI safety") + ], + locale="en-US", + ) + + # Mock current step + current_step = MagicMock() + current_step.execution_res = None + + # Mock partial agent messages (accumulated during execution before hitting limit) + tool_call = ToolCall( + name="web_search", + args={"query": "AI safety"}, + id="123" + ) + + partial_agent_messages = [ + HumanMessage(content="# Research Topic\n\nAI safety\n\n# Current Step\n\n## Title\n\nAnalyze AI safety"), + AIMessage(content="", tool_calls=[tool_call]), + HumanMessage(content="Tool result: Found 5 articles about AI safety"), + ] + + # Mock the LLM response + mock_llm_response = MagicMock() + mock_llm_response.content = "# Summary\n\nBased on the research, AI safety is important." + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + mock_get_system_prompt.return_value = "Fallback instructions" + + # Call the fallback function + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # Verify result is a list + assert isinstance(result, list) + + # Verify step execution result was set + assert current_step.execution_res == mock_llm_response.content + + # Verify messages include partial agent messages and the AI response + # Should have partial messages + 1 new AI response + assert len(result) == len(partial_agent_messages) + 1 + # Last message should be the fallback AI response + assert isinstance(result[-1], AIMessage) + assert result[-1].content == mock_llm_response.content + assert result[-1].name == "researcher" + # First messages should be from partial_agent_messages + assert result[0] == partial_agent_messages[0] + assert result[1] == partial_agent_messages[1] + assert result[2] == partial_agent_messages[2] + + @pytest.mark.asyncio + async def test_fallback_applies_prompt_template(self): + """Test that fallback applies the recursion_fallback prompt template.""" + state = State(messages=[], locale="zh-CN") + current_step = MagicMock() + # Create non-empty messages to avoid early return + partial_agent_messages = [HumanMessage(content="Test")] + + mock_llm_response = MagicMock() + mock_llm_response.content = "Summary in Chinese" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + mock_get_system_prompt.return_value = "Template rendered" + + await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # Verify get_system_prompt_template was called with correct arguments + assert mock_get_system_prompt.call_count == 2 # Called twice (once for agent, once for fallback) + + # Check the first call (for agent prompt) + first_call = mock_get_system_prompt.call_args_list[0] + assert first_call[0][0] == "researcher" # agent_name + assert first_call[0][1]["locale"] == "zh-CN" # locale in state + + # Check the second call (for recursion_fallback prompt) + second_call = mock_get_system_prompt.call_args_list[1] + assert second_call[0][0] == "recursion_fallback" # prompt_name + assert second_call[0][1]["locale"] == "zh-CN" # locale in state + + @pytest.mark.asyncio + async def test_fallback_gets_llm_without_tools(self): + """Test that fallback gets LLM without tools bound.""" + state = State(messages=[], locale="en-US") + current_step = MagicMock() + partial_agent_messages = [] + + mock_llm_response = MagicMock() + mock_llm_response.content = "Summary" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value="Template"), \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="coder", + current_step=current_step, + state=state, + ) + + # With empty messages, should return empty list + assert result == [] + + # Verify get_llm_by_type was NOT called (empty messages return early) + mock_get_llm.assert_not_called() + + @pytest.mark.asyncio + async def test_fallback_sanitizes_response(self): + """Test that fallback response is sanitized.""" + state = State(messages=[], locale="en-US") + current_step = MagicMock() + + # Create test messages (not empty) + partial_agent_messages = [HumanMessage(content="Test")] + + # Mock unsanitized response with extra tokens + mock_llm_response = MagicMock() + mock_llm_response.content = "Summary content" + + sanitized_content = "Summary content" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value=""), \ + patch("src.graph.nodes.sanitize_tool_response", return_value=sanitized_content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # Verify sanitized content was used + assert result[-1].content == sanitized_content + assert current_step.execution_res == sanitized_content + + @pytest.mark.asyncio + async def test_fallback_preserves_meta_fields(self): + """Test that fallback uses state locale correctly.""" + state = State( + messages=[], + locale="zh-CN", + research_topic="原始研究主题", + clarification_rounds=2, + ) + current_step = MagicMock() + + # Create test messages (not empty) + partial_agent_messages = [HumanMessage(content="Test")] + + mock_llm_response = MagicMock() + mock_llm_response.content = "Summary" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + mock_get_system_prompt.return_value = "Template" + + await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # Verify locale was passed to template + call_args = mock_get_system_prompt.call_args + assert call_args[0][1]["locale"] == "zh-CN" + + @pytest.mark.asyncio + async def test_fallback_raises_on_llm_failure(self): + """Test that fallback raises exception when LLM call fails.""" + state = State(messages=[], locale="en-US") + current_step = MagicMock() + + # Create test messages (not empty) + partial_agent_messages = [HumanMessage(content="Test")] + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value=""): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(side_effect=Exception("LLM API error")) + mock_get_llm.return_value = mock_llm + + # Should raise the exception + with pytest.raises(Exception, match="LLM API error"): + await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + @pytest.mark.asyncio + async def test_fallback_handles_different_agent_types(self): + """Test that fallback works with different agent types.""" + state = State(messages=[], locale="en-US") + + # Create test messages (not empty) + partial_agent_messages = [HumanMessage(content="Test")] + + mock_llm_response = MagicMock() + mock_llm_response.content = "Agent summary" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value=""), \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + + for agent_name in ["researcher", "coder", "analyst"]: + current_step = MagicMock() + + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name=agent_name, + current_step=current_step, + state=state, + ) + + # Verify agent name is set correctly + assert result[-1].name == agent_name + + @pytest.mark.asyncio + async def test_fallback_uses_partial_agent_messages(self): + """Test that fallback includes partial agent messages in result.""" + state = State(messages=[], locale="en-US") + current_step = MagicMock() + + # Create partial agent messages with tool calls + # Use proper tool_call format + from langchain_core.messages import ToolCall + + tool_call = ToolCall( + name="web_search", + args={"query": "test query"}, + id="123" + ) + + partial_agent_messages = [ + HumanMessage(content="Input message"), + AIMessage(content="", tool_calls=[tool_call]), + HumanMessage(content="Tool result: Search completed"), + ] + + mock_llm_response = MagicMock() + mock_llm_response.content = "Fallback summary" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value=""), \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # Verify partial messages are in result + result_messages = result + assert len(result_messages) == len(partial_agent_messages) + 1 + # First messages should be from partial_agent_messages + assert result_messages[0] == partial_agent_messages[0] + assert result_messages[1] == partial_agent_messages[1] + assert result_messages[2] == partial_agent_messages[2] + # Last message should be the fallback AI response + assert isinstance(result_messages[3], AIMessage) + assert result_messages[3].content == "Fallback summary" + + @pytest.mark.asyncio + async def test_fallback_handles_empty_partial_messages(self): + """Test that fallback handles empty partial_agent_messages.""" + state = State(messages=[], locale="en-US") + current_step = MagicMock() + partial_agent_messages = [] # Empty + + mock_llm_response = MagicMock() + mock_llm_response.content = "Fallback summary" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value=""), \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # With empty messages, should return empty list (early return) + assert result == [] + # Verify get_llm_by_type was NOT called (early return) + mock_get_llm.assert_not_called() + + +class TestRecursionFallbackConfiguration: + """Test suite for enable_recursion_fallback configuration.""" + + def test_config_default_is_enabled(self): + """Test that enable_recursion_fallback defaults to True.""" + config = Configuration() + + assert config.enable_recursion_fallback is True + + def test_config_from_env_variable_true(self): + """Test that enable_recursion_fallback can be set via environment variable.""" + with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "true"}): + config = Configuration() + assert config.enable_recursion_fallback is True + + def test_config_from_env_variable_false(self): + """Test that enable_recursion_fallback can be disabled via environment variable. + NOTE: This test documents the current behavior. The Configuration.from_runnable_config + method has a known issue where it doesn't properly convert boolean strings like "false" + to boolean False. This test reflects the actual (buggy) behavior and should be updated + when the Configuration class is fixed to use get_bool_env for boolean fields. + """ + with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "false"}): + config = Configuration() + # Currently returns True due to Configuration class bug + # Should return False when using get_bool_env properly + assert config.enable_recursion_fallback is True # Actual behavior + + def test_config_from_env_variable_1(self): + """Test that '1' is treated as True for enable_recursion_fallback.""" + with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "1"}): + config = Configuration() + assert config.enable_recursion_fallback is True + + def test_config_from_env_variable_0(self): + """Test that '0' is treated as False for enable_recursion_fallback. + NOTE: This test documents the current behavior. The Configuration class has a known + issue where string "0" is not properly converted to boolean False. + """ + with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "0"}): + config = Configuration() + # Currently returns True due to Configuration class bug + assert config.enable_recursion_fallback is True # Actual behavior + + def test_config_from_env_variable_yes(self): + """Test that 'yes' is treated as True for enable_recursion_fallback.""" + with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "yes"}): + config = Configuration() + assert config.enable_recursion_fallback is True + + def test_config_from_env_variable_no(self): + """Test that 'no' is treated as False for enable_recursion_fallback. + NOTE: This test documents the current behavior. The Configuration class has a known + issue where string "no" is not properly converted to boolean False. + """ + with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "no"}): + config = Configuration() + # Currently returns True due to Configuration class bug + assert config.enable_recursion_fallback is True # Actual behavior + + def test_config_from_runnable_config(self): + """Test that enable_recursion_fallback can be set via RunnableConfig.""" + from langchain_core.runnables import RunnableConfig + + # Test with False value + config_false = RunnableConfig(configurable={"enable_recursion_fallback": False}) + configuration_false = Configuration.from_runnable_config(config_false) + assert configuration_false.enable_recursion_fallback is False + + # Test with True value + config_true = RunnableConfig(configurable={"enable_recursion_fallback": True}) + configuration_true = Configuration.from_runnable_config(config_true) + assert configuration_true.enable_recursion_fallback is True + + def test_config_field_exists(self): + """Test that enable_recursion_fallback field exists in Configuration.""" + config = Configuration() + + assert hasattr(config, "enable_recursion_fallback") + assert isinstance(config.enable_recursion_fallback, bool) + + +class TestRecursionFallbackIntegration: + """Integration tests for recursion fallback in agent execution.""" + + @pytest.mark.asyncio + async def test_fallback_function_signature_returns_list(self): + """Test that the fallback function returns a list of messages.""" + from src.graph.nodes import _handle_recursion_limit_fallback + + state = State(messages=[], locale="en-US") + current_step = MagicMock() + # Create non-empty messages to avoid early return + partial_agent_messages = [HumanMessage(content="Test")] + + mock_llm_response = MagicMock() + mock_llm_response.content = "Summary" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value=""), \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + + # This should not raise - just verify the function returns a list + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # Verify it returns a list + assert isinstance(result, list) + + @pytest.mark.asyncio + async def test_configuration_enables_disables_fallback(self): + """Test that configuration controls fallback behavior.""" + configurable_enabled = Configuration(enable_recursion_fallback=True) + configurable_disabled = Configuration(enable_recursion_fallback=False) + + assert configurable_enabled.enable_recursion_fallback is True + assert configurable_disabled.enable_recursion_fallback is False + + +class TestRecursionFallbackEdgeCases: + """Test edge cases and boundary conditions for recursion fallback.""" + + @pytest.mark.asyncio + async def test_fallback_with_empty_observations(self): + """Test fallback behavior when there are no observations.""" + state = State(messages=[], locale="en-US") + current_step = MagicMock() + partial_agent_messages = [] + + mock_llm_response = MagicMock() + mock_llm_response.content = "No observations available" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value=""), \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # With empty messages, should return empty list + assert result == [] + + @pytest.mark.asyncio + async def test_fallback_with_very_long_recursion_limit(self): + """Test fallback with very large recursion limit value.""" + state = State(messages=[], locale="en-US") + current_step = MagicMock() + partial_agent_messages = [] + + mock_llm_response = MagicMock() + mock_llm_response.content = "Summary" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template", return_value=""), \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + + result = await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # With empty messages, should return empty list + assert result == [] + + @pytest.mark.asyncio + async def test_fallback_with_unicode_locale(self): + """Test fallback with various locale formats including unicode.""" + for locale in ["zh-CN", "ja-JP", "ko-KR", "en-US", "pt-BR"]: + state = State(messages=[], locale=locale) + current_step = MagicMock() + # Create non-empty messages to avoid early return + partial_agent_messages = [HumanMessage(content="Test")] + + mock_llm_response = MagicMock() + mock_llm_response.content = f"Summary for {locale}" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + mock_get_system_prompt.return_value = "Template" + + await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # Verify locale was passed to template + call_args = mock_get_system_prompt.call_args + assert call_args[0][1]["locale"] == locale + + @pytest.mark.asyncio + async def test_fallback_with_none_locale(self): + """Test fallback handles None locale gracefully.""" + state = State(messages=[], locale=None) + current_step = MagicMock() + # Create non-empty messages to avoid early return + partial_agent_messages = [HumanMessage(content="Test")] + + mock_llm_response = MagicMock() + mock_llm_response.content = "Summary" + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \ + patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \ + patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content): + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_llm_response) + mock_get_llm.return_value = mock_llm + mock_get_system_prompt.return_value = "Template" + + # Should not raise, should use default locale + await _handle_recursion_limit_fallback( + messages=partial_agent_messages, + agent_name="researcher", + current_step=current_step, + state=state, + ) + + # Verify default locale "en-US" was used + call_args = mock_get_system_prompt.call_args + assert call_args[0][1]["locale"] is None or call_args[0][1]["locale"] == "en-US"