diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index ebc1609..279f334 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -37,6 +37,7 @@ from langchain_core.runnables import RunnableConfig from deerflow.agents.lead_agent.agent import _build_middlewares from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.thread_state import ThreadState +from deerflow.config.agents_config import AGENT_NAME_PATTERN from deerflow.config.app_config import get_app_config, reload_app_config from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config from deerflow.config.paths import get_paths @@ -106,6 +107,7 @@ class DeerFlowClient: thinking_enabled: bool = True, subagent_enabled: bool = False, plan_mode: bool = False, + agent_name: str | None = None, ): """Initialize the client. @@ -120,16 +122,21 @@ class DeerFlowClient: thinking_enabled: Enable model's extended thinking. subagent_enabled: Enable subagent delegation. plan_mode: Enable TodoList middleware for plan mode. + agent_name: Name of the agent to use. """ if config_path is not None: reload_app_config(config_path) self._app_config = get_app_config() + if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name): + raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}") + self._checkpointer = checkpointer self._model_name = model_name self._thinking_enabled = thinking_enabled self._subagent_enabled = subagent_enabled self._plan_mode = plan_mode + self._agent_name = agent_name # Lazy agent — created on first call, recreated when config changes. self._agent = None @@ -202,10 +209,11 @@ class DeerFlowClient: kwargs: dict[str, Any] = { "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled), "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled), - "middleware": _build_middlewares(config, model_name=model_name), + "middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name), "system_prompt": apply_prompt_template( subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, + agent_name=self._agent_name, ), "state_schema": ThreadState, } @@ -219,7 +227,7 @@ class DeerFlowClient: self._agent = create_agent(**kwargs) self._agent_config_key = key - logger.info("Agent created: model=%s, thinking=%s", model_name, thinking_enabled) + logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled) @staticmethod def _get_tools(*, model_name: str | None, subagent_enabled: bool): @@ -338,6 +346,8 @@ class DeerFlowClient: state: dict[str, Any] = {"messages": [HumanMessage(content=message)]} context = {"thread_id": thread_id} + if self._agent_name: + context["agent_name"] = self._agent_name seen_ids: set[str] = set() cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py index f0b5d21..a6bc3b7 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/test_client.py @@ -56,6 +56,7 @@ class TestClientInit: assert client._thinking_enabled is True assert client._subagent_enabled is False assert client._plan_mode is False + assert client._agent_name is None assert client._checkpointer is None assert client._agent is None @@ -66,11 +67,20 @@ class TestClientInit: thinking_enabled=False, subagent_enabled=True, plan_mode=True, + agent_name="test-agent" ) assert c._model_name == "gpt-4" assert c._thinking_enabled is False assert c._subagent_enabled is True assert c._plan_mode is True + assert c._agent_name == "test-agent" + + def test_invalid_agent_name(self, mock_app_config): + with patch("deerflow.client.get_app_config", return_value=mock_app_config): + with pytest.raises(ValueError, match="Invalid agent name"): + DeerFlowClient(agent_name="invalid name with spaces!") + with pytest.raises(ValueError, match="Invalid agent name"): + DeerFlowClient(agent_name="../path/traversal") def test_custom_config_path(self, mock_app_config): with ( @@ -188,6 +198,23 @@ class TestStream: msg_events = _ai_events(events) assert msg_events[0].data["content"] == "Hello!" + def test_context_propagation(self, client): + """stream() passes agent_name to the context.""" + agent = _make_agent_mock([{"messages": [AIMessage(content="ok", id="ai-1")]}]) + + client._agent_name = "test-agent-1" + with ( + patch.object(client, "_ensure_agent"), + patch.object(client, "_agent", agent), + ): + list(client.stream("hi", thread_id="t1")) + + # Verify context passed to agent.stream + agent.stream.assert_called_once() + call_kwargs = agent.stream.call_args.kwargs + assert call_kwargs["context"]["thread_id"] == "t1" + assert call_kwargs["context"]["agent_name"] == "test-agent-1" + def test_tool_call_and_result(self, client): """stream() emits messages-tuple events for tool calls and results.""" ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}]) @@ -359,13 +386,19 @@ class TestEnsureAgent: with ( patch("deerflow.client.create_chat_model"), patch("deerflow.client.create_agent", return_value=mock_agent), - patch("deerflow.client._build_middlewares", return_value=[]), - patch("deerflow.client.apply_prompt_template", return_value="prompt"), + patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares, + patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt, patch.object(client, "_get_tools", return_value=[]), ): + client._agent_name = "custom-agent" client._ensure_agent(config) assert client._agent is mock_agent + # Verify agent_name propagation + mock_build_middlewares.assert_called_once() + assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent" + mock_apply_prompt.assert_called_once() + assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent" def test_uses_default_checkpointer_when_available(self, client): mock_agent = MagicMock()