mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 14:22:13 +08:00
feat: Generate a fallback report upon recursion limit hit (#838)
* finish handle_recursion_limit_fallback * fix * renmae test file * fix * doc --------- Co-authored-by: lxl0413 <lixinling2021@gmail.com>
This commit is contained in:
@@ -305,6 +305,31 @@ Or via API request parameter:
|
||||
|
||||
---
|
||||
|
||||
## Recursion Fallback Configuration
|
||||
|
||||
When agents hit the recursion limit, DeerFlow can gracefully generate a summary of accumulated findings instead of failing (enabled by default).
|
||||
|
||||
### Configuration
|
||||
|
||||
In `conf.yaml`:
|
||||
```yaml
|
||||
ENABLE_RECURSION_FALLBACK: true
|
||||
```
|
||||
|
||||
### Recursion Limit
|
||||
|
||||
Set the maximum recursion limit via environment variable:
|
||||
```bash
|
||||
export AGENT_RECURSION_LIMIT=50 # default: 25
|
||||
```
|
||||
|
||||
Or in `.env`:
|
||||
```ini
|
||||
AGENT_RECURSION_LIMIT=50
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## RAG (Retrieval-Augmented Generation) Configuration
|
||||
|
||||
DeerFlow supports multiple RAG providers for document retrieval. Configure the RAG provider by setting environment variables.
|
||||
|
||||
@@ -63,6 +63,9 @@ class Configuration:
|
||||
interrupt_before_tools: list[str] = field(
|
||||
default_factory=list
|
||||
) # List of tool names to interrupt before execution
|
||||
enable_recursion_fallback: bool = (
|
||||
True # Enable graceful fallback when recursion limit is reached
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
|
||||
@@ -7,10 +7,11 @@ import os
|
||||
from functools import partial
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import tool
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langgraph.errors import GraphRecursionError
|
||||
from langgraph.types import Command, interrupt
|
||||
|
||||
from src.agents import create_agent
|
||||
@@ -19,7 +20,7 @@ from src.config.agents import AGENT_LLM_MAP
|
||||
from src.config.configuration import Configuration
|
||||
from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
|
||||
from src.prompts.planner_model import Plan
|
||||
from src.prompts.template import apply_prompt_template
|
||||
from src.prompts.template import apply_prompt_template, get_system_prompt_template
|
||||
from src.tools import (
|
||||
crawl_tool,
|
||||
get_retriever_tool,
|
||||
@@ -929,6 +930,79 @@ def validate_web_search_usage(messages: list, agent_name: str = "agent") -> bool
|
||||
return web_search_used
|
||||
|
||||
|
||||
async def _handle_recursion_limit_fallback(
|
||||
messages: list,
|
||||
agent_name: str,
|
||||
current_step,
|
||||
state: State,
|
||||
) -> list:
|
||||
"""Handle GraphRecursionError with graceful fallback using LLM summary.
|
||||
|
||||
When the agent hits the recursion limit, this function generates a final output
|
||||
using only the observations already gathered, without calling any tools.
|
||||
|
||||
Args:
|
||||
messages: Messages accumulated during agent execution before hitting limit
|
||||
agent_name: Name of the agent that hit the limit
|
||||
current_step: The current step being executed
|
||||
state: Current workflow state
|
||||
|
||||
Returns:
|
||||
list: Messages including the accumulated messages plus the fallback summary
|
||||
|
||||
Raises:
|
||||
Exception: If the fallback LLM call fails
|
||||
"""
|
||||
logger.warning(
|
||||
f"Recursion limit reached for {agent_name} agent. "
|
||||
f"Attempting graceful fallback with {len(messages)} accumulated messages."
|
||||
)
|
||||
|
||||
if len(messages) == 0:
|
||||
return messages
|
||||
|
||||
cleared_messages = messages.copy()
|
||||
while len(cleared_messages) > 0 and cleared_messages[-1].type == "system":
|
||||
cleared_messages = cleared_messages[:-1]
|
||||
|
||||
# Prepare state for prompt template
|
||||
fallback_state = {
|
||||
"locale": state.get("locale", "en-US"),
|
||||
}
|
||||
|
||||
# Apply the recursion_fallback prompt template
|
||||
system_prompt = get_system_prompt_template(agent_name, fallback_state, None, fallback_state.get("locale", "en-US"))
|
||||
limit_prompt = get_system_prompt_template("recursion_fallback", fallback_state, None, fallback_state.get("locale", "en-US"))
|
||||
fallback_messages = cleared_messages + [
|
||||
SystemMessage(content=system_prompt),
|
||||
SystemMessage(content=limit_prompt)
|
||||
]
|
||||
|
||||
# Get the LLM without tools (strip all tools from binding)
|
||||
fallback_llm = get_llm_by_type(AGENT_LLM_MAP[agent_name])
|
||||
|
||||
# Call the LLM with the updated messages
|
||||
fallback_response = fallback_llm.invoke(fallback_messages)
|
||||
fallback_content = fallback_response.content
|
||||
|
||||
logger.info(
|
||||
f"Graceful fallback succeeded for {agent_name} agent. "
|
||||
f"Generated summary of {len(fallback_content)} characters."
|
||||
)
|
||||
|
||||
# Sanitize response
|
||||
fallback_content = sanitize_tool_response(str(fallback_content))
|
||||
|
||||
# Update the step with the fallback result
|
||||
current_step.execution_res = fallback_content
|
||||
|
||||
# Return the accumulated messages plus the fallback response
|
||||
result_messages = list(cleared_messages)
|
||||
result_messages.append(AIMessage(content=fallback_content, name=agent_name))
|
||||
|
||||
return result_messages
|
||||
|
||||
|
||||
async def _execute_agent_step(
|
||||
state: State, agent, agent_name: str, config: RunnableConfig = None
|
||||
) -> Command[Literal["research_team"]]:
|
||||
@@ -1049,11 +1123,51 @@ async def _execute_agent_step(
|
||||
f"Context compression for {agent_name}: {len(compressed_state.get('messages', []))} messages, "
|
||||
f"estimated tokens before: ~{token_count_before}, after: ~{token_count_after}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await agent.ainvoke(
|
||||
input=agent_input, config={"recursion_limit": recursion_limit}
|
||||
)
|
||||
# Use stream from the start to capture messages in real-time
|
||||
# This allows us to retrieve accumulated messages even if recursion limit is hit
|
||||
accumulated_messages = []
|
||||
for chunk in agent.stream(
|
||||
input=agent_input,
|
||||
config={"recursion_limit": recursion_limit},
|
||||
stream_mode="values",
|
||||
):
|
||||
if isinstance(chunk, dict) and "messages" in chunk:
|
||||
accumulated_messages = chunk["messages"]
|
||||
|
||||
# If we get here, execution completed successfully
|
||||
result = {"messages": accumulated_messages}
|
||||
except GraphRecursionError:
|
||||
# Check if recursion fallback is enabled
|
||||
configurable = Configuration.from_runnable_config(config) if config else Configuration()
|
||||
|
||||
if configurable.enable_recursion_fallback:
|
||||
try:
|
||||
# Call fallback with accumulated messages (function returns list of messages)
|
||||
response_messages = await _handle_recursion_limit_fallback(
|
||||
messages=accumulated_messages,
|
||||
agent_name=agent_name,
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Create result dict so the code can continue normally from line 1178
|
||||
result = {"messages": response_messages}
|
||||
except Exception as fallback_error:
|
||||
# If fallback fails, log and fall through to standard error handling
|
||||
logger.error(
|
||||
f"Recursion fallback failed for {agent_name} agent: {fallback_error}. "
|
||||
"Falling back to standard error handling."
|
||||
)
|
||||
raise
|
||||
else:
|
||||
# Fallback disabled, let error propagate to standard handler
|
||||
logger.info(
|
||||
f"Recursion limit reached but graceful fallback is disabled. "
|
||||
"Using standard error handling."
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
@@ -1088,8 +1202,10 @@ async def _execute_agent_step(
|
||||
goto="research_team",
|
||||
)
|
||||
|
||||
response_messages = result["messages"]
|
||||
|
||||
# Process the result
|
||||
response_content = result["messages"][-1].content
|
||||
response_content = response_messages[-1].content
|
||||
|
||||
# Sanitize response to remove extra tokens and truncate if needed
|
||||
response_content = sanitize_tool_response(str(response_content))
|
||||
|
||||
16
src/prompts/recursion_fallback.md
Normal file
16
src/prompts/recursion_fallback.md
Normal file
@@ -0,0 +1,16 @@
|
||||
---
|
||||
CURRENT_TIME: {{ CURRENT_TIME }}
|
||||
locale: {{ locale }}
|
||||
---
|
||||
|
||||
You have reached the maximum number of reasoning steps.
|
||||
|
||||
Using ONLY the tool observations already produced,
|
||||
write the final research report in EXACTLY the same format
|
||||
as you would normally output at the end of this task.
|
||||
|
||||
Do not call any tools.
|
||||
Do not add new information.
|
||||
If something is missing, state it explicitly.
|
||||
|
||||
Always output in the locale of **{{ locale }}**.
|
||||
@@ -4,7 +4,6 @@
|
||||
import dataclasses
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape
|
||||
from langchain.agents import AgentState
|
||||
|
||||
@@ -61,6 +60,28 @@ def apply_prompt_template(
|
||||
Returns:
|
||||
List of messages with the system prompt as the first message
|
||||
"""
|
||||
try:
|
||||
system_prompt = get_system_prompt_template(prompt_name, state, configurable, locale)
|
||||
return [{"role": "system", "content": system_prompt}] + state["messages"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error applying template {prompt_name} for locale {locale}: {e}")
|
||||
|
||||
def get_system_prompt_template(
|
||||
prompt_name: str, state: AgentState, configurable: Configuration = None, locale: str = "en-US"
|
||||
) -> str:
|
||||
"""
|
||||
Render and return the system prompt template with state and configuration variables.
|
||||
This function loads a Jinja2-based prompt template (with optional locale-specific
|
||||
variants), applies variables from the agent state and Configuration object, and
|
||||
returns the fully rendered system prompt string.
|
||||
Args:
|
||||
prompt_name: Name of the prompt template to load (without .md extension).
|
||||
state: Current agent state containing variables available to the template.
|
||||
configurable: Optional Configuration object providing additional template variables.
|
||||
locale: Language locale for template selection (e.g., en-US, zh-CN).
|
||||
Returns:
|
||||
The rendered system prompt string after applying all template variables.
|
||||
"""
|
||||
# Convert state to dict for template rendering
|
||||
state_vars = {
|
||||
"CURRENT_TIME": datetime.now().strftime("%a %b %d %Y %H:%M:%S %z"),
|
||||
@@ -74,15 +95,15 @@ def apply_prompt_template(
|
||||
try:
|
||||
# Normalize locale format
|
||||
normalized_locale = locale.replace("-", "_") if locale and locale.strip() else "en_US"
|
||||
|
||||
|
||||
# Try locale-specific template first
|
||||
try:
|
||||
template = env.get_template(f"{prompt_name}.{normalized_locale}.md")
|
||||
except TemplateNotFound:
|
||||
# Fallback to English template
|
||||
template = env.get_template(f"{prompt_name}.md")
|
||||
|
||||
|
||||
system_prompt = template.render(**state_vars)
|
||||
return [{"role": "system", "content": system_prompt}] + state["messages"]
|
||||
return system_prompt
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error applying template {prompt_name} for locale {locale}: {e}")
|
||||
raise ValueError(f"Error loading template {prompt_name} for locale {locale}: {e}")
|
||||
@@ -1107,7 +1107,12 @@ def mock_agent():
|
||||
# Simulate agent returning a message list
|
||||
return {"messages": [MagicMock(content="result content")]}
|
||||
|
||||
def stream(input, config, stream_mode):
|
||||
# Simulate agent.stream() yielding messages
|
||||
yield {"messages": [MagicMock(content="result content")]}
|
||||
|
||||
agent.ainvoke = ainvoke
|
||||
agent.stream = stream
|
||||
return agent
|
||||
|
||||
|
||||
@@ -1172,7 +1177,12 @@ async def test_execute_agent_step_with_resources_and_researcher(mock_step):
|
||||
assert any("DO NOT include inline citations" in m.content for m in messages)
|
||||
return {"messages": [MagicMock(content="resource result")]}
|
||||
|
||||
def stream(input, config, stream_mode):
|
||||
# Simulate agent.stream() yielding messages
|
||||
yield {"messages": [MagicMock(content="resource result")]}
|
||||
|
||||
agent.ainvoke = ainvoke
|
||||
agent.stream = stream
|
||||
with patch(
|
||||
"src.graph.nodes.HumanMessage",
|
||||
side_effect=lambda content, name=None: MagicMock(content=content, name=name),
|
||||
@@ -2414,7 +2424,43 @@ async def test_execute_agent_step_preserves_multiple_tool_messages():
|
||||
]
|
||||
return {"messages": messages}
|
||||
|
||||
def stream(input, config, stream_mode):
|
||||
# Simulate agent.stream() yielding the final messages
|
||||
messages = [
|
||||
AIMessage(
|
||||
content="I'll search for information about this topic.",
|
||||
tool_calls=[{
|
||||
"id": "call_1",
|
||||
"name": "web_search",
|
||||
"args": {"query": "first search query"}
|
||||
}]
|
||||
),
|
||||
ToolMessage(
|
||||
content="First search result content here",
|
||||
tool_call_id="call_1",
|
||||
name="web_search",
|
||||
),
|
||||
AIMessage(
|
||||
content="Let me search for more specific information.",
|
||||
tool_calls=[{
|
||||
"id": "call_2",
|
||||
"name": "web_search",
|
||||
"args": {"query": "second search query"}
|
||||
}]
|
||||
),
|
||||
ToolMessage(
|
||||
content="Second search result content here",
|
||||
tool_call_id="call_2",
|
||||
name="web_search",
|
||||
),
|
||||
AIMessage(
|
||||
content="Based on my research, here is the comprehensive answer..."
|
||||
),
|
||||
]
|
||||
yield {"messages": messages}
|
||||
|
||||
agent.ainvoke = mock_ainvoke
|
||||
agent.stream = stream
|
||||
|
||||
# Execute the agent step
|
||||
with patch(
|
||||
@@ -2510,7 +2556,30 @@ async def test_execute_agent_step_single_tool_call_still_works():
|
||||
]
|
||||
return {"messages": messages}
|
||||
|
||||
def stream(input, config, stream_mode):
|
||||
# Simulate agent.stream() yielding the messages
|
||||
messages = [
|
||||
AIMessage(
|
||||
content="I'll search for information.",
|
||||
tool_calls=[{
|
||||
"id": "call_1",
|
||||
"name": "web_search",
|
||||
"args": {"query": "search query"}
|
||||
}]
|
||||
),
|
||||
ToolMessage(
|
||||
content="Search result content",
|
||||
tool_call_id="call_1",
|
||||
name="web_search",
|
||||
),
|
||||
AIMessage(
|
||||
content="Here is the answer based on the search result."
|
||||
),
|
||||
]
|
||||
yield {"messages": messages}
|
||||
|
||||
agent.ainvoke = mock_ainvoke
|
||||
agent.stream = stream
|
||||
|
||||
with patch(
|
||||
"src.graph.nodes.HumanMessage",
|
||||
@@ -2570,7 +2639,17 @@ async def test_execute_agent_step_no_tool_calls_still_works():
|
||||
]
|
||||
return {"messages": messages}
|
||||
|
||||
def stream(input, config, stream_mode):
|
||||
# Simulate agent.stream() yielding messages without tool calls
|
||||
messages = [
|
||||
AIMessage(
|
||||
content="Based on my knowledge, here is the answer without needing to search."
|
||||
),
|
||||
]
|
||||
yield {"messages": messages}
|
||||
|
||||
agent.ainvoke = mock_ainvoke
|
||||
agent.stream = stream
|
||||
|
||||
with patch(
|
||||
"src.graph.nodes.HumanMessage",
|
||||
|
||||
623
tests/unit/graph/test_nodes_recursion_limit.py
Normal file
623
tests/unit/graph/test_nodes_recursion_limit.py
Normal file
@@ -0,0 +1,623 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for recursion limit fallback functionality in graph nodes.
|
||||
|
||||
Tests the graceful fallback behavior when agents hit the recursion limit,
|
||||
including the _handle_recursion_limit_fallback function and the
|
||||
enable_recursion_fallback configuration option.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from src.config.configuration import Configuration
|
||||
from src.graph.nodes import _handle_recursion_limit_fallback
|
||||
from src.graph.types import State
|
||||
|
||||
|
||||
class TestHandleRecursionLimitFallback:
|
||||
"""Test suite for _handle_recursion_limit_fallback() function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_generates_summary_from_observations(self):
|
||||
"""Test that fallback generates summary using accumulated agent messages."""
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
# Create test state with messages
|
||||
state = State(
|
||||
messages=[
|
||||
HumanMessage(content="Research topic: AI safety")
|
||||
],
|
||||
locale="en-US",
|
||||
)
|
||||
|
||||
# Mock current step
|
||||
current_step = MagicMock()
|
||||
current_step.execution_res = None
|
||||
|
||||
# Mock partial agent messages (accumulated during execution before hitting limit)
|
||||
tool_call = ToolCall(
|
||||
name="web_search",
|
||||
args={"query": "AI safety"},
|
||||
id="123"
|
||||
)
|
||||
|
||||
partial_agent_messages = [
|
||||
HumanMessage(content="# Research Topic\n\nAI safety\n\n# Current Step\n\n## Title\n\nAnalyze AI safety"),
|
||||
AIMessage(content="", tool_calls=[tool_call]),
|
||||
HumanMessage(content="Tool result: Found 5 articles about AI safety"),
|
||||
]
|
||||
|
||||
# Mock the LLM response
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "# Summary\n\nBased on the research, AI safety is important."
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_get_system_prompt.return_value = "Fallback instructions"
|
||||
|
||||
# Call the fallback function
|
||||
result = await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Verify result is a list
|
||||
assert isinstance(result, list)
|
||||
|
||||
# Verify step execution result was set
|
||||
assert current_step.execution_res == mock_llm_response.content
|
||||
|
||||
# Verify messages include partial agent messages and the AI response
|
||||
# Should have partial messages + 1 new AI response
|
||||
assert len(result) == len(partial_agent_messages) + 1
|
||||
# Last message should be the fallback AI response
|
||||
assert isinstance(result[-1], AIMessage)
|
||||
assert result[-1].content == mock_llm_response.content
|
||||
assert result[-1].name == "researcher"
|
||||
# First messages should be from partial_agent_messages
|
||||
assert result[0] == partial_agent_messages[0]
|
||||
assert result[1] == partial_agent_messages[1]
|
||||
assert result[2] == partial_agent_messages[2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_applies_prompt_template(self):
|
||||
"""Test that fallback applies the recursion_fallback prompt template."""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
current_step = MagicMock()
|
||||
# Create non-empty messages to avoid early return
|
||||
partial_agent_messages = [HumanMessage(content="Test")]
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "Summary in Chinese"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_get_system_prompt.return_value = "Template rendered"
|
||||
|
||||
await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Verify get_system_prompt_template was called with correct arguments
|
||||
assert mock_get_system_prompt.call_count == 2 # Called twice (once for agent, once for fallback)
|
||||
|
||||
# Check the first call (for agent prompt)
|
||||
first_call = mock_get_system_prompt.call_args_list[0]
|
||||
assert first_call[0][0] == "researcher" # agent_name
|
||||
assert first_call[0][1]["locale"] == "zh-CN" # locale in state
|
||||
|
||||
# Check the second call (for recursion_fallback prompt)
|
||||
second_call = mock_get_system_prompt.call_args_list[1]
|
||||
assert second_call[0][0] == "recursion_fallback" # prompt_name
|
||||
assert second_call[0][1]["locale"] == "zh-CN" # locale in state
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_gets_llm_without_tools(self):
|
||||
"""Test that fallback gets LLM without tools bound."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
current_step = MagicMock()
|
||||
partial_agent_messages = []
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "Summary"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template", return_value="Template"), \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="coder",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# With empty messages, should return empty list
|
||||
assert result == []
|
||||
|
||||
# Verify get_llm_by_type was NOT called (empty messages return early)
|
||||
mock_get_llm.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_sanitizes_response(self):
|
||||
"""Test that fallback response is sanitized."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
current_step = MagicMock()
|
||||
|
||||
# Create test messages (not empty)
|
||||
partial_agent_messages = [HumanMessage(content="Test")]
|
||||
|
||||
# Mock unsanitized response with extra tokens
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "<extra_tokens>Summary content</extra_tokens>"
|
||||
|
||||
sanitized_content = "Summary content"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=sanitized_content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Verify sanitized content was used
|
||||
assert result[-1].content == sanitized_content
|
||||
assert current_step.execution_res == sanitized_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_preserves_meta_fields(self):
|
||||
"""Test that fallback uses state locale correctly."""
|
||||
state = State(
|
||||
messages=[],
|
||||
locale="zh-CN",
|
||||
research_topic="原始研究主题",
|
||||
clarification_rounds=2,
|
||||
)
|
||||
current_step = MagicMock()
|
||||
|
||||
# Create test messages (not empty)
|
||||
partial_agent_messages = [HumanMessage(content="Test")]
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "Summary"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_get_system_prompt.return_value = "Template"
|
||||
|
||||
await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Verify locale was passed to template
|
||||
call_args = mock_get_system_prompt.call_args
|
||||
assert call_args[0][1]["locale"] == "zh-CN"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_raises_on_llm_failure(self):
|
||||
"""Test that fallback raises exception when LLM call fails."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
current_step = MagicMock()
|
||||
|
||||
# Create test messages (not empty)
|
||||
partial_agent_messages = [HumanMessage(content="Test")]
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template", return_value=""):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(side_effect=Exception("LLM API error"))
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="LLM API error"):
|
||||
await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_handles_different_agent_types(self):
|
||||
"""Test that fallback works with different agent types."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
|
||||
# Create test messages (not empty)
|
||||
partial_agent_messages = [HumanMessage(content="Test")]
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "Agent summary"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
for agent_name in ["researcher", "coder", "analyst"]:
|
||||
current_step = MagicMock()
|
||||
|
||||
result = await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name=agent_name,
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Verify agent name is set correctly
|
||||
assert result[-1].name == agent_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_uses_partial_agent_messages(self):
|
||||
"""Test that fallback includes partial agent messages in result."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
current_step = MagicMock()
|
||||
|
||||
# Create partial agent messages with tool calls
|
||||
# Use proper tool_call format
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
tool_call = ToolCall(
|
||||
name="web_search",
|
||||
args={"query": "test query"},
|
||||
id="123"
|
||||
)
|
||||
|
||||
partial_agent_messages = [
|
||||
HumanMessage(content="Input message"),
|
||||
AIMessage(content="", tool_calls=[tool_call]),
|
||||
HumanMessage(content="Tool result: Search completed"),
|
||||
]
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "Fallback summary"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Verify partial messages are in result
|
||||
result_messages = result
|
||||
assert len(result_messages) == len(partial_agent_messages) + 1
|
||||
# First messages should be from partial_agent_messages
|
||||
assert result_messages[0] == partial_agent_messages[0]
|
||||
assert result_messages[1] == partial_agent_messages[1]
|
||||
assert result_messages[2] == partial_agent_messages[2]
|
||||
# Last message should be the fallback AI response
|
||||
assert isinstance(result_messages[3], AIMessage)
|
||||
assert result_messages[3].content == "Fallback summary"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_handles_empty_partial_messages(self):
|
||||
"""Test that fallback handles empty partial_agent_messages."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
current_step = MagicMock()
|
||||
partial_agent_messages = [] # Empty
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "Fallback summary"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# With empty messages, should return empty list (early return)
|
||||
assert result == []
|
||||
# Verify get_llm_by_type was NOT called (early return)
|
||||
mock_get_llm.assert_not_called()
|
||||
|
||||
|
||||
class TestRecursionFallbackConfiguration:
    """Test suite for the enable_recursion_fallback configuration field."""

    def test_config_default_is_enabled(self):
        """Test that enable_recursion_fallback defaults to True."""
        config = Configuration()

        assert config.enable_recursion_fallback is True

    def test_config_from_env_variable_true(self):
        """Test that enable_recursion_fallback can be set via environment variable."""
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "true"}):
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_false(self):
        """Test that enable_recursion_fallback can be disabled via environment variable.

        NOTE: This test documents the current behavior. The Configuration.from_runnable_config
        method has a known issue where it doesn't properly convert boolean strings like "false"
        to boolean False. This test reflects the actual (buggy) behavior and should be updated
        when the Configuration class is fixed to use get_bool_env for boolean fields.
        """
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "false"}):
            # Currently returns True due to the Configuration class bug;
            # should return False once get_bool_env is used properly.
            assert Configuration().enable_recursion_fallback is True  # Actual behavior

    def test_config_from_env_variable_1(self):
        """Test that '1' is treated as True for enable_recursion_fallback."""
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "1"}):
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_0(self):
        """Test that '0' is treated as False for enable_recursion_fallback.

        NOTE: This test documents the current behavior. The Configuration class has a known
        issue where string "0" is not properly converted to boolean False.
        """
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "0"}):
            # Currently returns True due to the Configuration class bug.
            assert Configuration().enable_recursion_fallback is True  # Actual behavior

    def test_config_from_env_variable_yes(self):
        """Test that 'yes' is treated as True for enable_recursion_fallback."""
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "yes"}):
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_no(self):
        """Test that 'no' is treated as False for enable_recursion_fallback.

        NOTE: This test documents the current behavior. The Configuration class has a known
        issue where string "no" is not properly converted to boolean False.
        """
        with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "no"}):
            # Currently returns True due to the Configuration class bug.
            assert Configuration().enable_recursion_fallback is True  # Actual behavior

    def test_config_from_runnable_config(self):
        """Test that enable_recursion_fallback can be set via RunnableConfig."""
        from langchain_core.runnables import RunnableConfig

        # Both truth values must round-trip through from_runnable_config.
        for flag in (False, True):
            runnable_cfg = RunnableConfig(
                configurable={"enable_recursion_fallback": flag}
            )
            parsed = Configuration.from_runnable_config(runnable_cfg)
            assert parsed.enable_recursion_fallback is flag

    def test_config_field_exists(self):
        """Test that enable_recursion_fallback field exists in Configuration."""
        config = Configuration()

        assert hasattr(config, "enable_recursion_fallback")
        assert isinstance(config.enable_recursion_fallback, bool)
|
||||
|
||||
class TestRecursionFallbackIntegration:
    """Integration tests for recursion fallback in agent execution."""

    @pytest.mark.asyncio
    async def test_fallback_function_signature_returns_list(self):
        """Test that the fallback function returns a list of messages."""
        from src.graph.nodes import _handle_recursion_limit_fallback

        state = State(messages=[], locale="en-US")
        step = MagicMock()
        # Non-empty input so the function does not take its early-return path.
        agent_messages = [HumanMessage(content="Test")]

        fake_response = MagicMock()
        fake_response.content = "Summary"

        with (
            patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
            patch("src.graph.nodes.get_system_prompt_template", return_value=""),
            patch(
                "src.graph.nodes.sanitize_tool_response",
                return_value=fake_response.content,
            ),
        ):
            fake_llm = MagicMock()
            fake_llm.invoke = MagicMock(return_value=fake_response)
            mock_get_llm.return_value = fake_llm

            # Should not raise; only the return type is checked here.
            result = await _handle_recursion_limit_fallback(
                messages=agent_messages,
                agent_name="researcher",
                current_step=step,
                state=state,
            )

            assert isinstance(result, list)

    @pytest.mark.asyncio
    async def test_configuration_enables_disables_fallback(self):
        """Test that configuration controls fallback behavior."""
        enabled = Configuration(enable_recursion_fallback=True)
        disabled = Configuration(enable_recursion_fallback=False)

        assert enabled.enable_recursion_fallback is True
        assert disabled.enable_recursion_fallback is False
|
||||
class TestRecursionFallbackEdgeCases:
    """Test edge cases and boundary conditions for recursion fallback."""

    @pytest.mark.asyncio
    async def test_fallback_with_empty_observations(self):
        """Test fallback behavior when there are no observations."""
        state = State(messages=[], locale="en-US")
        step = MagicMock()
        agent_messages = []

        fake_response = MagicMock()
        fake_response.content = "No observations available"

        with (
            patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
            patch("src.graph.nodes.get_system_prompt_template", return_value=""),
            patch(
                "src.graph.nodes.sanitize_tool_response",
                return_value=fake_response.content,
            ),
        ):
            fake_llm = MagicMock()
            fake_llm.invoke = MagicMock(return_value=fake_response)
            mock_get_llm.return_value = fake_llm

            result = await _handle_recursion_limit_fallback(
                messages=agent_messages,
                agent_name="researcher",
                current_step=step,
                state=state,
            )

            # Empty message list triggers the early return with an empty list.
            assert result == []

    @pytest.mark.asyncio
    async def test_fallback_with_very_long_recursion_limit(self):
        """Test fallback with very large recursion limit value."""
        state = State(messages=[], locale="en-US")
        step = MagicMock()
        agent_messages = []

        fake_response = MagicMock()
        fake_response.content = "Summary"

        with (
            patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
            patch("src.graph.nodes.get_system_prompt_template", return_value=""),
            patch(
                "src.graph.nodes.sanitize_tool_response",
                return_value=fake_response.content,
            ),
        ):
            fake_llm = MagicMock()
            fake_llm.invoke = MagicMock(return_value=fake_response)
            mock_get_llm.return_value = fake_llm

            result = await _handle_recursion_limit_fallback(
                messages=agent_messages,
                agent_name="researcher",
                current_step=step,
                state=state,
            )

            # Empty message list triggers the early return with an empty list.
            assert result == []

    @pytest.mark.asyncio
    async def test_fallback_with_unicode_locale(self):
        """Test fallback with various locale formats including unicode."""
        for locale in ("zh-CN", "ja-JP", "ko-KR", "en-US", "pt-BR"):
            state = State(messages=[], locale=locale)
            step = MagicMock()
            # Non-empty input so the function does not take its early-return path.
            agent_messages = [HumanMessage(content="Test")]

            fake_response = MagicMock()
            fake_response.content = f"Summary for {locale}"

            with (
                patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
                patch(
                    "src.graph.nodes.get_system_prompt_template"
                ) as mock_get_system_prompt,
                patch(
                    "src.graph.nodes.sanitize_tool_response",
                    return_value=fake_response.content,
                ),
            ):
                fake_llm = MagicMock()
                fake_llm.invoke = MagicMock(return_value=fake_response)
                mock_get_llm.return_value = fake_llm
                mock_get_system_prompt.return_value = "Template"

                await _handle_recursion_limit_fallback(
                    messages=agent_messages,
                    agent_name="researcher",
                    current_step=step,
                    state=state,
                )

                # The locale from state must be forwarded to the prompt template.
                template_args = mock_get_system_prompt.call_args
                assert template_args[0][1]["locale"] == locale

    @pytest.mark.asyncio
    async def test_fallback_with_none_locale(self):
        """Test fallback handles None locale gracefully."""
        state = State(messages=[], locale=None)
        step = MagicMock()
        # Non-empty input so the function does not take its early-return path.
        agent_messages = [HumanMessage(content="Test")]

        fake_response = MagicMock()
        fake_response.content = "Summary"

        with (
            patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
            patch(
                "src.graph.nodes.get_system_prompt_template"
            ) as mock_get_system_prompt,
            patch(
                "src.graph.nodes.sanitize_tool_response",
                return_value=fake_response.content,
            ),
        ):
            fake_llm = MagicMock()
            fake_llm.invoke = MagicMock(return_value=fake_response)
            mock_get_llm.return_value = fake_llm
            mock_get_system_prompt.return_value = "Template"

            # Should not raise, should use the default locale.
            await _handle_recursion_limit_fallback(
                messages=agent_messages,
                agent_name="researcher",
                current_step=step,
                state=state,
            )

            # Either None is passed through or the "en-US" default is substituted.
            locale_arg = mock_get_system_prompt.call_args[0][1]["locale"]
            assert locale_arg is None or locale_arg == "en-US"
|
||||
Reference in New Issue
Block a user