feat(client): support agent_name injection to enable isolated memory and custom prompts (#1253)

* feat(client): 添加agent_name参数支持自定义代理名称

允许在初始化DeerFlowClient时指定代理名称,该名称将用于中间件构建和系统提示模板

* test: add coverage for agent_name parameter in DeerFlowClient

* fix(client): address PR review comments for agent_name injection

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
knukn
2026-03-23 17:44:21 +08:00
committed by GitHub
parent f6c54e0308
commit fe75cb35ca
2 changed files with 47 additions and 4 deletions

View File

@@ -37,6 +37,7 @@ from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import get_app_config, reload_app_config
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.paths import get_paths
@@ -106,6 +107,7 @@ class DeerFlowClient:
thinking_enabled: bool = True,
subagent_enabled: bool = False,
plan_mode: bool = False,
agent_name: str | None = None,
):
"""Initialize the client.
@@ -120,16 +122,21 @@ class DeerFlowClient:
thinking_enabled: Enable model's extended thinking.
subagent_enabled: Enable subagent delegation.
plan_mode: Enable TodoList middleware for plan mode.
agent_name: Name of the agent to use.
"""
if config_path is not None:
reload_app_config(config_path)
self._app_config = get_app_config()
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
self._checkpointer = checkpointer
self._model_name = model_name
self._thinking_enabled = thinking_enabled
self._subagent_enabled = subagent_enabled
self._plan_mode = plan_mode
self._agent_name = agent_name
# Lazy agent — created on first call, recreated when config changes.
self._agent = None
@@ -202,10 +209,11 @@ class DeerFlowClient:
kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name),
"system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
),
"state_schema": ThreadState,
}
@@ -219,7 +227,7 @@ class DeerFlowClient:
self._agent = create_agent(**kwargs)
self._agent_config_key = key
logger.info("Agent created: model=%s, thinking=%s", model_name, thinking_enabled)
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
@staticmethod
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
@@ -338,6 +346,8 @@ class DeerFlowClient:
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
context = {"thread_id": thread_id}
if self._agent_name:
context["agent_name"] = self._agent_name
seen_ids: set[str] = set()
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

View File

@@ -56,6 +56,7 @@ class TestClientInit:
assert client._thinking_enabled is True
assert client._subagent_enabled is False
assert client._plan_mode is False
assert client._agent_name is None
assert client._checkpointer is None
assert client._agent is None
@@ -66,11 +67,20 @@ class TestClientInit:
thinking_enabled=False,
subagent_enabled=True,
plan_mode=True,
agent_name="test-agent"
)
assert c._model_name == "gpt-4"
assert c._thinking_enabled is False
assert c._subagent_enabled is True
assert c._plan_mode is True
assert c._agent_name == "test-agent"
def test_invalid_agent_name(self, mock_app_config):
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="invalid name with spaces!")
with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="../path/traversal")
def test_custom_config_path(self, mock_app_config):
with (
@@ -188,6 +198,23 @@ class TestStream:
msg_events = _ai_events(events)
assert msg_events[0].data["content"] == "Hello!"
def test_context_propagation(self, client):
"""stream() passes agent_name to the context."""
agent = _make_agent_mock([{"messages": [AIMessage(content="ok", id="ai-1")]}])
client._agent_name = "test-agent-1"
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
list(client.stream("hi", thread_id="t1"))
# Verify context passed to agent.stream
agent.stream.assert_called_once()
call_kwargs = agent.stream.call_args.kwargs
assert call_kwargs["context"]["thread_id"] == "t1"
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
def test_tool_call_and_result(self, client):
"""stream() emits messages-tuple events for tool calls and results."""
ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}])
@@ -359,13 +386,19 @@ class TestEnsureAgent:
with (
patch("deerflow.client.create_chat_model"),
patch("deerflow.client.create_agent", return_value=mock_agent),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
patch.object(client, "_get_tools", return_value=[]),
):
client._agent_name = "custom-agent"
client._ensure_agent(config)
assert client._agent is mock_agent
# Verify agent_name propagation
mock_build_middlewares.assert_called_once()
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
mock_apply_prompt.assert_called_once()
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
def test_uses_default_checkpointer_when_available(self, client):
mock_agent = MagicMock()