mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-25 15:04:46 +08:00
fix:env AGENT_RECURSION_LIMIT not work (#453)
* fix:env AGENT_RECURSION_LIMIT not work * fix:add test * black tests/unit/config/test_configuration.py --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -10,6 +11,39 @@ from langchain_core.runnables import RunnableConfig
|
|||||||
from src.rag.retriever import Resource
|
from src.rag.retriever import Resource
|
||||||
from src.config.report_style import ReportStyle
|
from src.config.report_style import ReportStyle
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_recursion_limit(default: int = 25) -> int:
|
||||||
|
"""Get the recursion limit from environment variable or use default.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default: Default recursion limit if environment variable is not set or invalid
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The recursion limit to use
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
env_value_str = os.getenv("AGENT_RECURSION_LIMIT", str(default))
|
||||||
|
parsed_limit = int(env_value_str)
|
||||||
|
|
||||||
|
if parsed_limit > 0:
|
||||||
|
logger.info(f"Recursion limit set to: {parsed_limit}")
|
||||||
|
return parsed_limit
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"AGENT_RECURSION_LIMIT value '{env_value_str}' (parsed as {parsed_limit}) is not positive. "
|
||||||
|
f"Using default value {default}."
|
||||||
|
)
|
||||||
|
return default
|
||||||
|
except ValueError:
|
||||||
|
raw_env_value = os.getenv("AGENT_RECURSION_LIMIT")
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid AGENT_RECURSION_LIMIT value: '{raw_env_value}'. "
|
||||||
|
f"Using default value {default}."
|
||||||
|
)
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
class Configuration:
|
class Configuration:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from fastapi.responses import Response, StreamingResponse
|
|||||||
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
from src.config.configuration import get_recursion_limit
|
||||||
from src.config.report_style import ReportStyle
|
from src.config.report_style import ReportStyle
|
||||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||||
from src.graph.builder import build_graph_with_memory
|
from src.graph.builder import build_graph_with_memory
|
||||||
@@ -150,6 +151,7 @@ async def _astream_workflow_generator(
|
|||||||
"mcp_settings": mcp_settings,
|
"mcp_settings": mcp_settings,
|
||||||
"report_style": report_style.value,
|
"report_style": report_style.value,
|
||||||
"enable_deep_thinking": enable_deep_thinking,
|
"enable_deep_thinking": enable_deep_thinking,
|
||||||
|
"recursion_limit": get_recursion_limit(),
|
||||||
},
|
},
|
||||||
stream_mode=["messages", "updates"],
|
stream_mode=["messages", "updates"],
|
||||||
subgraphs=True,
|
subgraphs=True,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from src.config.configuration import get_recursion_limit
|
||||||
from src.graph import build_graph
|
from src.graph import build_graph
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
@@ -71,7 +72,7 @@ async def run_agent_workflow_async(
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"recursion_limit": 100,
|
"recursion_limit": get_recursion_limit(default=100),
|
||||||
}
|
}
|
||||||
last_message_cnt = 0
|
last_message_cnt = 0
|
||||||
async for s in graph.astream(
|
async for s in graph.astream(
|
||||||
|
|||||||
@@ -88,3 +88,49 @@ def test_from_runnable_config_with_no_config():
|
|||||||
assert config.max_search_results == 3
|
assert config.max_search_results == 3
|
||||||
assert config.resources == []
|
assert config.resources == []
|
||||||
assert config.mcp_settings is None
|
assert config.mcp_settings is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recursion_limit_default():
|
||||||
|
from src.config.configuration import get_recursion_limit
|
||||||
|
|
||||||
|
result = get_recursion_limit()
|
||||||
|
assert result == 25
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recursion_limit_custom_default():
|
||||||
|
from src.config.configuration import get_recursion_limit
|
||||||
|
|
||||||
|
result = get_recursion_limit(50)
|
||||||
|
assert result == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recursion_limit_from_env(monkeypatch):
|
||||||
|
from src.config.configuration import get_recursion_limit
|
||||||
|
|
||||||
|
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "100")
|
||||||
|
result = get_recursion_limit()
|
||||||
|
assert result == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recursion_limit_invalid_env_value(monkeypatch):
|
||||||
|
from src.config.configuration import get_recursion_limit
|
||||||
|
|
||||||
|
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "invalid")
|
||||||
|
result = get_recursion_limit()
|
||||||
|
assert result == 25
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recursion_limit_negative_env_value(monkeypatch):
|
||||||
|
from src.config.configuration import get_recursion_limit
|
||||||
|
|
||||||
|
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "-5")
|
||||||
|
result = get_recursion_limit()
|
||||||
|
assert result == 25
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recursion_limit_zero_env_value(monkeypatch):
|
||||||
|
from src.config.configuration import get_recursion_limit
|
||||||
|
|
||||||
|
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "0")
|
||||||
|
result = get_recursion_limit()
|
||||||
|
assert result == 25
|
||||||
|
|||||||
Reference in New Issue
Block a user