feat: Generate a fallback report upon recursion limit hit (#838)

* finish handle_recursion_limit_fallback

* fix

* rename test file

* fix

* doc

---------

Co-authored-by: lxl0413 <lixinling2021@gmail.com>
This commit is contained in:
Xun
2026-01-26 21:10:18 +08:00
committed by GitHub
parent 9a34e32252
commit ee02b9f637
7 changed files with 895 additions and 12 deletions

View File

@@ -305,6 +305,31 @@ Or via API request parameter:
---
## Recursion Fallback Configuration
When agents hit the recursion limit, DeerFlow can gracefully generate a summary of accumulated findings instead of failing (enabled by default).
### Configuration
In `conf.yaml`:
```yaml
ENABLE_RECURSION_FALLBACK: true
```
### Recursion Limit
Set the maximum recursion limit via environment variable:
```bash
export AGENT_RECURSION_LIMIT=50 # default: 25
```
Or in `.env`:
```ini
AGENT_RECURSION_LIMIT=50
```
---
## RAG (Retrieval-Augmented Generation) Configuration
DeerFlow supports multiple RAG providers for document retrieval. Configure the RAG provider by setting environment variables.

View File

@@ -63,6 +63,9 @@ class Configuration:
interrupt_before_tools: list[str] = field(
default_factory=list
) # List of tool names to interrupt before execution
enable_recursion_fallback: bool = (
True # Enable graceful fallback when recursion limit is reached
)
@classmethod
def from_runnable_config(

View File

@@ -7,10 +7,11 @@ import os
from functools import partial
from typing import Annotated, Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.errors import GraphRecursionError
from langgraph.types import Command, interrupt
from src.agents import create_agent
@@ -19,7 +20,7 @@ from src.config.agents import AGENT_LLM_MAP
from src.config.configuration import Configuration
from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
from src.prompts.planner_model import Plan
from src.prompts.template import apply_prompt_template
from src.prompts.template import apply_prompt_template, get_system_prompt_template
from src.tools import (
crawl_tool,
get_retriever_tool,
@@ -929,6 +930,79 @@ def validate_web_search_usage(messages: list, agent_name: str = "agent") -> bool
return web_search_used
async def _handle_recursion_limit_fallback(
    messages: list,
    agent_name: str,
    current_step,
    state: State,
) -> list:
    """Recover gracefully from a GraphRecursionError by summarizing via the LLM.

    When the agent hits the recursion limit, ask the agent's LLM — with no
    tools bound — to produce a final output from the observations that were
    already gathered, rather than failing the whole step.

    Args:
        messages: Messages accumulated during agent execution before the limit.
        agent_name: Name of the agent that hit the limit.
        current_step: The plan step currently executing; its ``execution_res``
            is filled in with the fallback summary.
        state: Current workflow state (only the locale is read here).

    Returns:
        list: The accumulated messages followed by the fallback AI summary,
        or the input unchanged when nothing was accumulated.

    Raises:
        Exception: Propagated if the fallback LLM call fails.
    """
    logger.warning(
        f"Recursion limit reached for {agent_name} agent. "
        f"Attempting graceful fallback with {len(messages)} accumulated messages."
    )

    # Nothing was gathered before the limit -> nothing to summarize.
    if not messages:
        return messages

    # Drop trailing system messages so the fallback prompts appended below
    # end up last in the conversation.
    trimmed = messages.copy()
    while trimmed and trimmed[-1].type == "system":
        trimmed.pop()

    # Minimal state for prompt rendering: only the locale is required.
    locale = state.get("locale", "en-US")
    fallback_state = {"locale": locale}

    # Render the agent's own system prompt plus the recursion_fallback prompt.
    system_prompt = get_system_prompt_template(agent_name, fallback_state, None, locale)
    limit_prompt = get_system_prompt_template("recursion_fallback", fallback_state, None, locale)

    prompt_messages = trimmed + [
        SystemMessage(content=system_prompt),
        SystemMessage(content=limit_prompt),
    ]

    # Use the plain LLM for this agent type (no tools bound), so no further
    # tool calls are possible during the fallback.
    fallback_llm = get_llm_by_type(AGENT_LLM_MAP[agent_name])
    fallback_response = fallback_llm.invoke(prompt_messages)
    fallback_content = fallback_response.content

    logger.info(
        f"Graceful fallback succeeded for {agent_name} agent. "
        f"Generated summary of {len(fallback_content)} characters."
    )

    # Strip extra tokens / truncate before recording the result.
    fallback_content = sanitize_tool_response(str(fallback_content))
    current_step.execution_res = fallback_content

    # Hand back the accumulated history plus the synthesized final answer.
    final_messages = list(trimmed)
    final_messages.append(AIMessage(content=fallback_content, name=agent_name))
    return final_messages
async def _execute_agent_step(
state: State, agent, agent_name: str, config: RunnableConfig = None
) -> Command[Literal["research_team"]]:
@@ -1049,11 +1123,51 @@ async def _execute_agent_step(
f"Context compression for {agent_name}: {len(compressed_state.get('messages', []))} messages, "
f"estimated tokens before: ~{token_count_before}, after: ~{token_count_after}"
)
try:
result = await agent.ainvoke(
input=agent_input, config={"recursion_limit": recursion_limit}
)
# Use stream from the start to capture messages in real-time
# This allows us to retrieve accumulated messages even if recursion limit is hit
accumulated_messages = []
for chunk in agent.stream(
input=agent_input,
config={"recursion_limit": recursion_limit},
stream_mode="values",
):
if isinstance(chunk, dict) and "messages" in chunk:
accumulated_messages = chunk["messages"]
# If we get here, execution completed successfully
result = {"messages": accumulated_messages}
except GraphRecursionError:
# Check if recursion fallback is enabled
configurable = Configuration.from_runnable_config(config) if config else Configuration()
if configurable.enable_recursion_fallback:
try:
# Call fallback with accumulated messages (function returns list of messages)
response_messages = await _handle_recursion_limit_fallback(
messages=accumulated_messages,
agent_name=agent_name,
current_step=current_step,
state=state,
)
# Create result dict so the code can continue normally from line 1178
result = {"messages": response_messages}
except Exception as fallback_error:
# If fallback fails, log and fall through to standard error handling
logger.error(
f"Recursion fallback failed for {agent_name} agent: {fallback_error}. "
"Falling back to standard error handling."
)
raise
else:
# Fallback disabled, let error propagate to standard handler
logger.info(
f"Recursion limit reached but graceful fallback is disabled. "
"Using standard error handling."
)
raise
except Exception as e:
import traceback
@@ -1088,8 +1202,10 @@ async def _execute_agent_step(
goto="research_team",
)
response_messages = result["messages"]
# Process the result
response_content = result["messages"][-1].content
response_content = response_messages[-1].content
# Sanitize response to remove extra tokens and truncate if needed
response_content = sanitize_tool_response(str(response_content))

View File

@@ -0,0 +1,16 @@
---
CURRENT_TIME: {{ CURRENT_TIME }}
locale: {{ locale }}
---
You have reached the maximum number of reasoning steps.
Using ONLY the tool observations already produced,
write the final research report in EXACTLY the same format
as you would normally output at the end of this task.
Do not call any tools.
Do not add new information.
If something is missing, state it explicitly.
Always output in the locale of **{{ locale }}**.

View File

@@ -4,7 +4,6 @@
import dataclasses
import os
from datetime import datetime
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape
from langchain.agents import AgentState
@@ -61,6 +60,28 @@ def apply_prompt_template(
Returns:
List of messages with the system prompt as the first message
"""
try:
system_prompt = get_system_prompt_template(prompt_name, state, configurable, locale)
return [{"role": "system", "content": system_prompt}] + state["messages"]
except Exception as e:
raise ValueError(f"Error applying template {prompt_name} for locale {locale}: {e}")
def get_system_prompt_template(
prompt_name: str, state: AgentState, configurable: Configuration = None, locale: str = "en-US"
) -> str:
"""
Render and return the system prompt template with state and configuration variables.
This function loads a Jinja2-based prompt template (with optional locale-specific
variants), applies variables from the agent state and Configuration object, and
returns the fully rendered system prompt string.
Args:
prompt_name: Name of the prompt template to load (without .md extension).
state: Current agent state containing variables available to the template.
configurable: Optional Configuration object providing additional template variables.
locale: Language locale for template selection (e.g., en-US, zh-CN).
Returns:
The rendered system prompt string after applying all template variables.
"""
# Convert state to dict for template rendering
state_vars = {
"CURRENT_TIME": datetime.now().strftime("%a %b %d %Y %H:%M:%S %z"),
@@ -74,15 +95,15 @@ def apply_prompt_template(
try:
# Normalize locale format
normalized_locale = locale.replace("-", "_") if locale and locale.strip() else "en_US"
# Try locale-specific template first
try:
template = env.get_template(f"{prompt_name}.{normalized_locale}.md")
except TemplateNotFound:
# Fallback to English template
template = env.get_template(f"{prompt_name}.md")
system_prompt = template.render(**state_vars)
return [{"role": "system", "content": system_prompt}] + state["messages"]
return system_prompt
except Exception as e:
raise ValueError(f"Error applying template {prompt_name} for locale {locale}: {e}")
raise ValueError(f"Error loading template {prompt_name} for locale {locale}: {e}")

View File

@@ -1107,7 +1107,12 @@ def mock_agent():
# Simulate agent returning a message list
return {"messages": [MagicMock(content="result content")]}
def stream(input, config, stream_mode):
# Simulate agent.stream() yielding messages
yield {"messages": [MagicMock(content="result content")]}
agent.ainvoke = ainvoke
agent.stream = stream
return agent
@@ -1172,7 +1177,12 @@ async def test_execute_agent_step_with_resources_and_researcher(mock_step):
assert any("DO NOT include inline citations" in m.content for m in messages)
return {"messages": [MagicMock(content="resource result")]}
def stream(input, config, stream_mode):
# Simulate agent.stream() yielding messages
yield {"messages": [MagicMock(content="resource result")]}
agent.ainvoke = ainvoke
agent.stream = stream
with patch(
"src.graph.nodes.HumanMessage",
side_effect=lambda content, name=None: MagicMock(content=content, name=name),
@@ -2414,7 +2424,43 @@ async def test_execute_agent_step_preserves_multiple_tool_messages():
]
return {"messages": messages}
def stream(input, config, stream_mode):
# Simulate agent.stream() yielding the final messages
messages = [
AIMessage(
content="I'll search for information about this topic.",
tool_calls=[{
"id": "call_1",
"name": "web_search",
"args": {"query": "first search query"}
}]
),
ToolMessage(
content="First search result content here",
tool_call_id="call_1",
name="web_search",
),
AIMessage(
content="Let me search for more specific information.",
tool_calls=[{
"id": "call_2",
"name": "web_search",
"args": {"query": "second search query"}
}]
),
ToolMessage(
content="Second search result content here",
tool_call_id="call_2",
name="web_search",
),
AIMessage(
content="Based on my research, here is the comprehensive answer..."
),
]
yield {"messages": messages}
agent.ainvoke = mock_ainvoke
agent.stream = stream
# Execute the agent step
with patch(
@@ -2510,7 +2556,30 @@ async def test_execute_agent_step_single_tool_call_still_works():
]
return {"messages": messages}
def stream(input, config, stream_mode):
# Simulate agent.stream() yielding the messages
messages = [
AIMessage(
content="I'll search for information.",
tool_calls=[{
"id": "call_1",
"name": "web_search",
"args": {"query": "search query"}
}]
),
ToolMessage(
content="Search result content",
tool_call_id="call_1",
name="web_search",
),
AIMessage(
content="Here is the answer based on the search result."
),
]
yield {"messages": messages}
agent.ainvoke = mock_ainvoke
agent.stream = stream
with patch(
"src.graph.nodes.HumanMessage",
@@ -2570,7 +2639,17 @@ async def test_execute_agent_step_no_tool_calls_still_works():
]
return {"messages": messages}
def stream(input, config, stream_mode):
# Simulate agent.stream() yielding messages without tool calls
messages = [
AIMessage(
content="Based on my knowledge, here is the answer without needing to search."
),
]
yield {"messages": messages}
agent.ainvoke = mock_ainvoke
agent.stream = stream
with patch(
"src.graph.nodes.HumanMessage",

View File

@@ -0,0 +1,623 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for recursion limit fallback functionality in graph nodes.
Tests the graceful fallback behavior when agents hit the recursion limit,
including the _handle_recursion_limit_fallback function and the
enable_recursion_fallback configuration option.
"""
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from src.config.configuration import Configuration
from src.graph.nodes import _handle_recursion_limit_fallback
from src.graph.types import State
class TestHandleRecursionLimitFallback:
    """Test suite for _handle_recursion_limit_fallback() function."""

    @pytest.mark.asyncio
    async def test_fallback_generates_summary_from_observations(self):
        """Test that fallback generates summary using accumulated agent messages."""
        from langchain_core.messages import ToolCall

        # Create test state with messages
        state = State(
            messages=[
                HumanMessage(content="Research topic: AI safety")
            ],
            locale="en-US",
        )
        # Mock current step
        current_step = MagicMock()
        current_step.execution_res = None
        # Mock partial agent messages (accumulated during execution before hitting limit)
        tool_call = ToolCall(
            name="web_search",
            args={"query": "AI safety"},
            id="123"
        )
        partial_agent_messages = [
            HumanMessage(content="# Research Topic\n\nAI safety\n\n# Current Step\n\n## Title\n\nAnalyze AI safety"),
            AIMessage(content="", tool_calls=[tool_call]),
            HumanMessage(content="Tool result: Found 5 articles about AI safety"),
        ]
        # Mock the LLM response
        mock_llm_response = MagicMock()
        mock_llm_response.content = "# Summary\n\nBased on the research, AI safety is important."
        # Patch the collaborators at their use site (src.graph.nodes) so the
        # fallback never touches a real LLM or prompt template.
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            mock_get_system_prompt.return_value = "Fallback instructions"
            # Call the fallback function
            result = await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # Verify result is a list
            assert isinstance(result, list)
            # Verify step execution result was set
            assert current_step.execution_res == mock_llm_response.content
            # Verify messages include partial agent messages and the AI response
            # Should have partial messages + 1 new AI response
            assert len(result) == len(partial_agent_messages) + 1
            # Last message should be the fallback AI response
            assert isinstance(result[-1], AIMessage)
            assert result[-1].content == mock_llm_response.content
            assert result[-1].name == "researcher"
            # First messages should be from partial_agent_messages
            assert result[0] == partial_agent_messages[0]
            assert result[1] == partial_agent_messages[1]
            assert result[2] == partial_agent_messages[2]

    @pytest.mark.asyncio
    async def test_fallback_applies_prompt_template(self):
        """Test that fallback applies the recursion_fallback prompt template."""
        state = State(messages=[], locale="zh-CN")
        current_step = MagicMock()
        # Create non-empty messages to avoid early return
        partial_agent_messages = [HumanMessage(content="Test")]
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Summary in Chinese"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            mock_get_system_prompt.return_value = "Template rendered"
            await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # Verify get_system_prompt_template was called with correct arguments
            assert mock_get_system_prompt.call_count == 2  # Called twice (once for agent, once for fallback)
            # Check the first call (for agent prompt)
            first_call = mock_get_system_prompt.call_args_list[0]
            assert first_call[0][0] == "researcher"  # agent_name
            assert first_call[0][1]["locale"] == "zh-CN"  # locale in state
            # Check the second call (for recursion_fallback prompt)
            second_call = mock_get_system_prompt.call_args_list[1]
            assert second_call[0][0] == "recursion_fallback"  # prompt_name
            assert second_call[0][1]["locale"] == "zh-CN"  # locale in state

    @pytest.mark.asyncio
    async def test_fallback_gets_llm_without_tools(self):
        """Test that fallback gets LLM without tools bound."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        partial_agent_messages = []
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Summary"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value="Template"), \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            result = await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="coder",
                current_step=current_step,
                state=state,
            )
            # With empty messages, should return empty list
            assert result == []
            # Verify get_llm_by_type was NOT called (empty messages return early)
            mock_get_llm.assert_not_called()

    @pytest.mark.asyncio
    async def test_fallback_sanitizes_response(self):
        """Test that fallback response is sanitized."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        # Create test messages (not empty)
        partial_agent_messages = [HumanMessage(content="Test")]
        # Mock unsanitized response with extra tokens
        mock_llm_response = MagicMock()
        mock_llm_response.content = "<extra_tokens>Summary content</extra_tokens>"
        sanitized_content = "Summary content"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
             patch("src.graph.nodes.sanitize_tool_response", return_value=sanitized_content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            result = await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # Verify sanitized content was used
            assert result[-1].content == sanitized_content
            assert current_step.execution_res == sanitized_content

    @pytest.mark.asyncio
    async def test_fallback_preserves_meta_fields(self):
        """Test that fallback uses state locale correctly."""
        state = State(
            messages=[],
            locale="zh-CN",
            research_topic="原始研究主题",
            clarification_rounds=2,
        )
        current_step = MagicMock()
        # Create test messages (not empty)
        partial_agent_messages = [HumanMessage(content="Test")]
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Summary"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            mock_get_system_prompt.return_value = "Template"
            await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # Verify locale was passed to template
            call_args = mock_get_system_prompt.call_args
            assert call_args[0][1]["locale"] == "zh-CN"

    @pytest.mark.asyncio
    async def test_fallback_raises_on_llm_failure(self):
        """Test that fallback raises exception when LLM call fails."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        # Create test messages (not empty)
        partial_agent_messages = [HumanMessage(content="Test")]
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value=""):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(side_effect=Exception("LLM API error"))
            mock_get_llm.return_value = mock_llm
            # Should raise the exception
            with pytest.raises(Exception, match="LLM API error"):
                await _handle_recursion_limit_fallback(
                    messages=partial_agent_messages,
                    agent_name="researcher",
                    current_step=current_step,
                    state=state,
                )

    @pytest.mark.asyncio
    async def test_fallback_handles_different_agent_types(self):
        """Test that fallback works with different agent types."""
        state = State(messages=[], locale="en-US")
        # Create test messages (not empty)
        partial_agent_messages = [HumanMessage(content="Test")]
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Agent summary"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            for agent_name in ["researcher", "coder", "analyst"]:
                current_step = MagicMock()
                result = await _handle_recursion_limit_fallback(
                    messages=partial_agent_messages,
                    agent_name=agent_name,
                    current_step=current_step,
                    state=state,
                )
                # Verify agent name is set correctly
                assert result[-1].name == agent_name

    @pytest.mark.asyncio
    async def test_fallback_uses_partial_agent_messages(self):
        """Test that fallback includes partial agent messages in result."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        # Create partial agent messages with tool calls
        # Use proper tool_call format
        from langchain_core.messages import ToolCall
        tool_call = ToolCall(
            name="web_search",
            args={"query": "test query"},
            id="123"
        )
        partial_agent_messages = [
            HumanMessage(content="Input message"),
            AIMessage(content="", tool_calls=[tool_call]),
            HumanMessage(content="Tool result: Search completed"),
        ]
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Fallback summary"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            result = await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # Verify partial messages are in result
            result_messages = result
            assert len(result_messages) == len(partial_agent_messages) + 1
            # First messages should be from partial_agent_messages
            assert result_messages[0] == partial_agent_messages[0]
            assert result_messages[1] == partial_agent_messages[1]
            assert result_messages[2] == partial_agent_messages[2]
            # Last message should be the fallback AI response
            assert isinstance(result_messages[3], AIMessage)
            assert result_messages[3].content == "Fallback summary"

    @pytest.mark.asyncio
    async def test_fallback_handles_empty_partial_messages(self):
        """Test that fallback handles empty partial_agent_messages."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        partial_agent_messages = []  # Empty
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Fallback summary"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            result = await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # With empty messages, should return empty list (early return)
            assert result == []
            # Verify get_llm_by_type was NOT called (early return)
            mock_get_llm.assert_not_called()
class TestRecursionFallbackConfiguration:
    """Test suite for enable_recursion_fallback configuration."""

    def test_config_default_is_enabled(self):
        """Test that enable_recursion_fallback defaults to True."""
        config = Configuration()
        assert config.enable_recursion_fallback is True

    def test_config_from_env_variable_true(self):
        """Test that enable_recursion_fallback can be set via environment variable."""
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "true"}):
            config = Configuration()
            assert config.enable_recursion_fallback is True

    def test_config_from_env_variable_false(self):
        """Test that enable_recursion_fallback can be disabled via environment variable.

        NOTE: This test documents the current behavior. The Configuration.from_runnable_config
        method has a known issue where it doesn't properly convert boolean strings like "false"
        to boolean False. This test reflects the actual (buggy) behavior and should be updated
        when the Configuration class is fixed to use get_bool_env for boolean fields.
        """
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "false"}):
            config = Configuration()
            # Currently returns True due to Configuration class bug
            # Should return False when using get_bool_env properly
            assert config.enable_recursion_fallback is True  # Actual behavior

    def test_config_from_env_variable_1(self):
        """Test that '1' is treated as True for enable_recursion_fallback."""
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "1"}):
            config = Configuration()
            assert config.enable_recursion_fallback is True

    def test_config_from_env_variable_0(self):
        """Test that '0' is treated as False for enable_recursion_fallback.

        NOTE: This test documents the current behavior. The Configuration class has a known
        issue where string "0" is not properly converted to boolean False.
        """
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "0"}):
            config = Configuration()
            # Currently returns True due to Configuration class bug
            assert config.enable_recursion_fallback is True  # Actual behavior

    def test_config_from_env_variable_yes(self):
        """Test that 'yes' is treated as True for enable_recursion_fallback."""
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "yes"}):
            config = Configuration()
            assert config.enable_recursion_fallback is True

    def test_config_from_env_variable_no(self):
        """Test that 'no' is treated as False for enable_recursion_fallback.

        NOTE: This test documents the current behavior. The Configuration class has a known
        issue where string "no" is not properly converted to boolean False.
        """
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "no"}):
            config = Configuration()
            # Currently returns True due to Configuration class bug
            assert config.enable_recursion_fallback is True  # Actual behavior

    def test_config_from_runnable_config(self):
        """Test that enable_recursion_fallback can be set via RunnableConfig."""
        from langchain_core.runnables import RunnableConfig
        # Test with False value
        config_false = RunnableConfig(configurable={"enable_recursion_fallback": False})
        configuration_false = Configuration.from_runnable_config(config_false)
        assert configuration_false.enable_recursion_fallback is False
        # Test with True value
        config_true = RunnableConfig(configurable={"enable_recursion_fallback": True})
        configuration_true = Configuration.from_runnable_config(config_true)
        assert configuration_true.enable_recursion_fallback is True

    def test_config_field_exists(self):
        """Test that enable_recursion_fallback field exists in Configuration."""
        config = Configuration()
        assert hasattr(config, "enable_recursion_fallback")
        assert isinstance(config.enable_recursion_fallback, bool)
class TestRecursionFallbackIntegration:
    """Integration tests for recursion fallback in agent execution."""

    @pytest.mark.asyncio
    async def test_fallback_function_signature_returns_list(self):
        """Test that the fallback function returns a list of messages."""
        from src.graph.nodes import _handle_recursion_limit_fallback
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        # Create non-empty messages to avoid early return
        partial_agent_messages = [HumanMessage(content="Test")]
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Summary"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            # This should not raise - just verify the function returns a list
            result = await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # Verify it returns a list
            assert isinstance(result, list)

    @pytest.mark.asyncio
    async def test_configuration_enables_disables_fallback(self):
        """Test that configuration controls fallback behavior."""
        configurable_enabled = Configuration(enable_recursion_fallback=True)
        configurable_disabled = Configuration(enable_recursion_fallback=False)
        assert configurable_enabled.enable_recursion_fallback is True
        assert configurable_disabled.enable_recursion_fallback is False
class TestRecursionFallbackEdgeCases:
    """Test edge cases and boundary conditions for recursion fallback."""

    @pytest.mark.asyncio
    async def test_fallback_with_empty_observations(self):
        """Test fallback behavior when there are no observations."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        partial_agent_messages = []
        mock_llm_response = MagicMock()
        mock_llm_response.content = "No observations available"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            result = await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # With empty messages, should return empty list
            assert result == []

    @pytest.mark.asyncio
    async def test_fallback_with_very_long_recursion_limit(self):
        """Test fallback with very large recursion limit value."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        partial_agent_messages = []
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Summary"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            result = await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # With empty messages, should return empty list
            assert result == []

    @pytest.mark.asyncio
    async def test_fallback_with_unicode_locale(self):
        """Test fallback with various locale formats including unicode."""
        for locale in ["zh-CN", "ja-JP", "ko-KR", "en-US", "pt-BR"]:
            state = State(messages=[], locale=locale)
            current_step = MagicMock()
            # Create non-empty messages to avoid early return
            partial_agent_messages = [HumanMessage(content="Test")]
            mock_llm_response = MagicMock()
            mock_llm_response.content = f"Summary for {locale}"
            with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
                 patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
                 patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
                mock_llm = MagicMock()
                mock_llm.invoke = MagicMock(return_value=mock_llm_response)
                mock_get_llm.return_value = mock_llm
                mock_get_system_prompt.return_value = "Template"
                await _handle_recursion_limit_fallback(
                    messages=partial_agent_messages,
                    agent_name="researcher",
                    current_step=current_step,
                    state=state,
                )
                # Verify locale was passed to template
                call_args = mock_get_system_prompt.call_args
                assert call_args[0][1]["locale"] == locale

    @pytest.mark.asyncio
    async def test_fallback_with_none_locale(self):
        """Test fallback handles None locale gracefully."""
        state = State(messages=[], locale=None)
        current_step = MagicMock()
        # Create non-empty messages to avoid early return
        partial_agent_messages = [HumanMessage(content="Test")]
        mock_llm_response = MagicMock()
        mock_llm_response.content = "Summary"
        with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
             patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
             patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=mock_llm_response)
            mock_get_llm.return_value = mock_llm
            mock_get_system_prompt.return_value = "Template"
            # Should not raise, should use default locale
            await _handle_recursion_limit_fallback(
                messages=partial_agent_messages,
                agent_name="researcher",
                current_step=current_step,
                state=state,
            )
            # Verify default locale "en-US" was used
            # NOTE(review): the fallback uses state.get("locale", "en-US"), which
            # returns None when the key exists with a None value — hence the
            # permissive assertion accepting either None or "en-US".
            call_args = mock_get_system_prompt.call_args
            assert call_args[0][1]["locale"] is None or call_args[0][1]["locale"] == "en-US"