feat: Generate a fallback report upon recursion limit hit (#838)

* finish handle_recursion_limit_fallback

* fix

* rename test file

* fix

* doc

---------

Co-authored-by: lxl0413 <lixinling2021@gmail.com>
This commit is contained in:
Xun
2026-01-26 21:10:18 +08:00
committed by GitHub
parent 9a34e32252
commit ee02b9f637
7 changed files with 895 additions and 12 deletions

View File

@@ -7,10 +7,11 @@ import os
from functools import partial
from typing import Annotated, Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.errors import GraphRecursionError
from langgraph.types import Command, interrupt
from src.agents import create_agent
@@ -19,7 +20,7 @@ from src.config.agents import AGENT_LLM_MAP
from src.config.configuration import Configuration
from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
from src.prompts.planner_model import Plan
from src.prompts.template import apply_prompt_template
from src.prompts.template import apply_prompt_template, get_system_prompt_template
from src.tools import (
crawl_tool,
get_retriever_tool,
@@ -929,6 +930,79 @@ def validate_web_search_usage(messages: list, agent_name: str = "agent") -> bool
return web_search_used
async def _handle_recursion_limit_fallback(
    messages: list,
    agent_name: str,
    current_step,
    state: State,
) -> list:
    """Build a final answer from already-gathered messages after a recursion-limit hit.

    Trailing system messages are dropped from the accumulated history, the
    agent's own system prompt and the ``recursion_fallback`` prompt are
    appended, and the agent's LLM is asked once for a closing response. No
    tools are bound to the model inside this function.

    Args:
        messages: Messages accumulated during agent execution before the
            limit was hit.
        agent_name: Name of the agent that hit the limit; also used to look
            up its prompt template and its entry in ``AGENT_LLM_MAP``.
        current_step: The step being executed; its ``execution_res``
            attribute is overwritten with the sanitized fallback text.
        state: Current workflow state; only ``locale`` is read from it.

    Returns:
        list: ``messages`` unchanged when it is empty; otherwise the history
        (minus trailing system messages) followed by one ``AIMessage``
        carrying the fallback text.

    Raises:
        Exception: Any error raised while rendering the prompts or calling
            the LLM propagates to the caller.
    """
    logger.warning(
        f"Recursion limit reached for {agent_name} agent. "
        f"Attempting graceful fallback with {len(messages)} accumulated messages."
    )

    # Nothing was gathered, so there is nothing to summarize.
    if not messages:
        return messages

    # Work on a copy so the caller's list is never mutated, then strip any
    # system messages sitting at the tail of the history.
    cleared_messages = list(messages)
    while cleared_messages and cleared_messages[-1].type == "system":
        cleared_messages.pop()

    # Minimal state for prompt rendering: only the locale is forwarded.
    locale = state.get("locale", "en-US")
    fallback_state = {"locale": locale}

    system_prompt = get_system_prompt_template(
        agent_name, fallback_state, None, locale
    )
    limit_prompt = get_system_prompt_template(
        "recursion_fallback", fallback_state, None, locale
    )

    # NOTE(review): the system prompts are placed *after* the conversation;
    # some chat providers may only accept system messages at the start --
    # confirm against the models configured in AGENT_LLM_MAP.
    fallback_messages = cleared_messages + [
        SystemMessage(content=system_prompt),
        SystemMessage(content=limit_prompt),
    ]

    # Plain LLM for this agent type; no tools are bound here, so the model
    # can only answer from what is already in the message history.
    fallback_llm = get_llm_by_type(AGENT_LLM_MAP[agent_name])

    # Await the async entry point: this is a coroutine, and a synchronous
    # invoke() would block the event loop for the whole LLM round trip.
    fallback_response = await fallback_llm.ainvoke(fallback_messages)

    # Message content may be a str or a list of content blocks; stringify
    # first so the logged length really is a character count.
    fallback_content = str(fallback_response.content)
    logger.info(
        f"Graceful fallback succeeded for {agent_name} agent. "
        f"Generated summary of {len(fallback_content)} characters."
    )

    # Sanitize response
    fallback_content = sanitize_tool_response(fallback_content)

    # Record the fallback text as the step's execution result.
    current_step.execution_res = fallback_content

    # Return the cleaned history plus the fallback answer attributed to the agent.
    return cleared_messages + [AIMessage(content=fallback_content, name=agent_name)]
async def _execute_agent_step(
state: State, agent, agent_name: str, config: RunnableConfig = None
) -> Command[Literal["research_team"]]:
@@ -1049,11 +1123,51 @@ async def _execute_agent_step(
f"Context compression for {agent_name}: {len(compressed_state.get('messages', []))} messages, "
f"estimated tokens before: ~{token_count_before}, after: ~{token_count_after}"
)
try:
result = await agent.ainvoke(
input=agent_input, config={"recursion_limit": recursion_limit}
)
# Use stream from the start to capture messages in real-time
# This allows us to retrieve accumulated messages even if recursion limit is hit
accumulated_messages = []
for chunk in agent.stream(
input=agent_input,
config={"recursion_limit": recursion_limit},
stream_mode="values",
):
if isinstance(chunk, dict) and "messages" in chunk:
accumulated_messages = chunk["messages"]
# If we get here, execution completed successfully
result = {"messages": accumulated_messages}
except GraphRecursionError:
# Check if recursion fallback is enabled
configurable = Configuration.from_runnable_config(config) if config else Configuration()
if configurable.enable_recursion_fallback:
try:
# Call fallback with accumulated messages (function returns list of messages)
response_messages = await _handle_recursion_limit_fallback(
messages=accumulated_messages,
agent_name=agent_name,
current_step=current_step,
state=state,
)
# Create result dict so the code can continue normally from line 1178
result = {"messages": response_messages}
except Exception as fallback_error:
# If fallback fails, log and fall through to standard error handling
logger.error(
f"Recursion fallback failed for {agent_name} agent: {fallback_error}. "
"Falling back to standard error handling."
)
raise
else:
# Fallback disabled, let error propagate to standard handler
logger.info(
f"Recursion limit reached but graceful fallback is disabled. "
"Using standard error handling."
)
raise
except Exception as e:
import traceback
@@ -1088,8 +1202,10 @@ async def _execute_agent_step(
goto="research_team",
)
response_messages = result["messages"]
# Process the result
response_content = result["messages"][-1].content
response_content = response_messages[-1].content
# Sanitize response to remove extra tokens and truncate if needed
response_content = sanitize_tool_response(str(response_content))