mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-20 04:44:46 +08:00
feat(client): support agent_name injection to enable isolated memory and custom prompts (#1253)
* feat(client): 添加agent_name参数支持自定义代理名称 允许在初始化DeerFlowClient时指定代理名称,该名称将用于中间件构建和系统提示模板 * test: add coverage for agent_name parameter in DeerFlowClient * fix(client): address PR review comments for agent_name injection --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -56,6 +56,7 @@ class TestClientInit:
|
||||
assert client._thinking_enabled is True
|
||||
assert client._subagent_enabled is False
|
||||
assert client._plan_mode is False
|
||||
assert client._agent_name is None
|
||||
assert client._checkpointer is None
|
||||
assert client._agent is None
|
||||
|
||||
@@ -66,11 +67,20 @@ class TestClientInit:
|
||||
thinking_enabled=False,
|
||||
subagent_enabled=True,
|
||||
plan_mode=True,
|
||||
agent_name="test-agent"
|
||||
)
|
||||
assert c._model_name == "gpt-4"
|
||||
assert c._thinking_enabled is False
|
||||
assert c._subagent_enabled is True
|
||||
assert c._plan_mode is True
|
||||
assert c._agent_name == "test-agent"
|
||||
|
||||
def test_invalid_agent_name(self, mock_app_config):
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="invalid name with spaces!")
|
||||
with pytest.raises(ValueError, match="Invalid agent name"):
|
||||
DeerFlowClient(agent_name="../path/traversal")
|
||||
|
||||
def test_custom_config_path(self, mock_app_config):
|
||||
with (
|
||||
@@ -188,6 +198,23 @@ class TestStream:
|
||||
msg_events = _ai_events(events)
|
||||
assert msg_events[0].data["content"] == "Hello!"
|
||||
|
||||
def test_context_propagation(self, client):
|
||||
"""stream() passes agent_name to the context."""
|
||||
agent = _make_agent_mock([{"messages": [AIMessage(content="ok", id="ai-1")]}])
|
||||
|
||||
client._agent_name = "test-agent-1"
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
list(client.stream("hi", thread_id="t1"))
|
||||
|
||||
# Verify context passed to agent.stream
|
||||
agent.stream.assert_called_once()
|
||||
call_kwargs = agent.stream.call_args.kwargs
|
||||
assert call_kwargs["context"]["thread_id"] == "t1"
|
||||
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
|
||||
|
||||
def test_tool_call_and_result(self, client):
|
||||
"""stream() emits messages-tuple events for tool calls and results."""
|
||||
ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}])
|
||||
@@ -359,13 +386,19 @@ class TestEnsureAgent:
|
||||
with (
|
||||
patch("deerflow.client.create_chat_model"),
|
||||
patch("deerflow.client.create_agent", return_value=mock_agent),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
):
|
||||
client._agent_name = "custom-agent"
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert client._agent is mock_agent
|
||||
# Verify agent_name propagation
|
||||
mock_build_middlewares.assert_called_once()
|
||||
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
|
||||
mock_apply_prompt.assert_called_once()
|
||||
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
|
||||
|
||||
def test_uses_default_checkpointer_when_available(self, client):
|
||||
mock_agent = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user