test: add Gateway conformance tests for DeerFlowClient (#931)

Validate that all dict-returning client methods conform to Gateway
Pydantic response models (ModelsListResponse, ModelResponse,
SkillsListResponse, SkillResponse, SkillInstallResponse,
McpConfigResponse, UploadResponse, MemoryConfigResponse,
MemoryStatusResponse). Pydantic ValidationError in CI catches
schema drift between client and Gateway with zero production coupling.

Also includes prior review fixes: enhanced client methods, expanded
unit tests (67→77), live integration test improvements, and updated
documentation.

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
greatmengqi
2026-02-28 16:08:04 +08:00
committed by GitHub
parent 9d48c42a20
commit 30d948711f
5 changed files with 625 additions and 232 deletions

View File

@@ -238,7 +238,7 @@ DeerFlow is model-agnostic — it works with any LLM that implements the OpenAI-
## Embedded Python Client ## Embedded Python Client
DeerFlow can be used as an embedded Python library without running the full HTTP services. The `DeerFlowClient` provides direct in-process access to all agent and Gateway capabilities: DeerFlow can be used as an embedded Python library without running the full HTTP services. The `DeerFlowClient` provides direct in-process access to all agent and Gateway capabilities, returning the same response schemas as the HTTP Gateway API:
```python ```python
from src.client import DeerFlowClient from src.client import DeerFlowClient
@@ -248,18 +248,19 @@ client = DeerFlowClient()
# Chat # Chat
response = client.chat("Analyze this paper for me", thread_id="my-thread") response = client.chat("Analyze this paper for me", thread_id="my-thread")
# Streaming # Streaming (LangGraph SSE protocol: values, messages-tuple, end)
for event in client.stream("hello"): for event in client.stream("hello"):
print(event.type, event.data) if event.type == "messages-tuple" and event.data.get("type") == "ai":
print(event.data["content"])
# Configuration & management # Configuration & management — returns Gateway-aligned dicts
print(client.list_models()) models = client.list_models() # {"models": [...]}
print(client.list_skills()) skills = client.list_skills() # {"skills": [...]}
client.update_skill("web-search", enabled=True) client.update_skill("web-search", enabled=True)
client.upload_files("thread-1", ["./report.pdf"]) client.upload_files("thread-1", ["./report.pdf"]) # {"success": True, "files": [...]}
``` ```
See `backend/src/client.py` for full API documentation. All dict-returning methods are validated against Gateway Pydantic response models in CI (`TestGatewayConformance`), ensuring the embedded client stays in sync with the HTTP API schemas. See `backend/src/client.py` for full API documentation.
## Documentation ## Documentation

View File

@@ -294,31 +294,36 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
### Embedded Client (`src/client.py`) ### Embedded Client (`src/client.py`)
`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. `DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. All return types align with the Gateway API response schemas, so consumer code works identically in HTTP and embedded modes.
**Architecture**: Imports the same `src/` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency. **Architecture**: Imports the same `src/` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency.
**Agent Conversation** (replaces LangGraph Server): **Agent Conversation** (replaces LangGraph Server):
- `chat(message, thread_id)` — synchronous, returns final text - `chat(message, thread_id)` — synchronous, returns final text
- `stream(message, thread_id)` — yields `StreamEvent` (message, tool_call, tool_result, title, done) - `stream(message, thread_id)` — yields `StreamEvent` aligned with LangGraph SSE protocol:
- `"values"` — full state snapshot (title, messages, artifacts)
- `"messages-tuple"` — per-message update (AI text, tool calls, tool results)
- `"end"` — stream finished
- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent` - Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent`
- Supports `checkpointer` parameter for state persistence across turns - Supports `checkpointer` parameter for state persistence across turns
- Invocation pattern: `agent.stream(state, config, context, stream_mode="values")` - `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
**Gateway Equivalent Methods** (replaces Gateway API): **Gateway Equivalent Methods** (replaces Gateway API):
| Category | Methods | | Category | Methods | Return format |
|----------|---------| |----------|---------|---------------|
| Models | `list_models()`, `get_model(name)` | | Models | `list_models()`, `get_model(name)` | `{"models": [...]}`, `{name, display_name, ...}` |
| MCP | `get_mcp_config()`, `update_mcp_config(servers)` | | MCP | `get_mcp_config()`, `update_mcp_config(servers)` | `{"mcp_servers": {...}}` |
| Skills | `list_skills()`, `get_skill(name)`, `update_skill(name, enabled)`, `install_skill(path)` | | Skills | `list_skills()`, `get_skill(name)`, `update_skill(name, enabled)`, `install_skill(path)` | `{"skills": [...]}` |
| Memory | `get_memory()`, `reload_memory()`, `get_memory_config()`, `get_memory_status()` | | Memory | `get_memory()`, `reload_memory()`, `get_memory_config()`, `get_memory_status()` | dict |
| Uploads | `upload_files(thread_id, files)`, `list_uploads(thread_id)`, `delete_upload(thread_id, filename)` | | Uploads | `upload_files(thread_id, files)`, `list_uploads(thread_id)`, `delete_upload(thread_id, filename)` | `{"success": true, "files": [...]}`, `{"files": [...], "count": N}` |
| Artifacts | `get_artifact(thread_id, path)` → `(bytes, mime_type)` | | Artifacts | `get_artifact(thread_id, path)` → `(bytes, mime_type)` | tuple |
**Key difference from Gateway**: Upload accepts local `Path` objects instead of HTTP `UploadFile`. Artifact returns `(bytes, mime_type)` instead of HTTP Response. **Key difference from Gateway**: Upload accepts local `Path` objects instead of HTTP `UploadFile`. Artifact returns `(bytes, mime_type)` instead of HTTP Response. `update_mcp_config()` and `update_skill()` automatically invalidate the cached agent.
**Tests**: `tests/test_client.py` (45 unit tests) **Tests**: `tests/test_client.py` (77 unit tests including `TestGatewayConformance`), `tests/test_client_live.py` (live integration tests, requires config.yaml)
**Gateway Conformance Tests** (`TestGatewayConformance`): Validate that every dict-returning client method conforms to the corresponding Gateway Pydantic response model. Each test parses the client output through the Gateway model — if Gateway adds a required field that the client doesn't provide, Pydantic raises `ValidationError` and CI catches the drift. Covers: `ModelsListResponse`, `ModelResponse`, `SkillsListResponse`, `SkillResponse`, `SkillInstallResponse`, `McpConfigResponse`, `UploadResponse`, `MemoryConfigResponse`, `MemoryStatusResponse`.
## Development Workflow ## Development Workflow

View File

@@ -30,7 +30,7 @@ from pathlib import Path
from typing import Any from typing import Any
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from src.agents.lead_agent.agent import _build_middlewares from src.agents.lead_agent.agent import _build_middlewares
@@ -38,8 +38,8 @@ from src.agents.lead_agent.prompt import apply_prompt_template
from src.agents.thread_state import ThreadState from src.agents.thread_state import ThreadState
from src.config.app_config import get_app_config, reload_app_config from src.config.app_config import get_app_config, reload_app_config
from src.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config from src.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from src.models import create_chat_model
from src.config.paths import get_paths from src.config.paths import get_paths
from src.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -48,8 +48,13 @@ logger = logging.getLogger(__name__)
class StreamEvent: class StreamEvent:
"""A single event from the streaming agent response. """A single event from the streaming agent response.
Event types align with the LangGraph SSE protocol:
- ``"values"``: Full state snapshot (title, messages, artifacts).
- ``"messages-tuple"``: Per-message update (AI text, tool calls, tool results).
- ``"end"``: Stream finished.
Attributes: Attributes:
type: Event type — "message", "tool_call", "tool_result", "title", or "done". type: Event type.
data: Event payload. Contents vary by type. data: Event payload. Contents vary by type.
""" """
@@ -214,6 +219,28 @@ class DeerFlowClient:
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
@staticmethod
def _serialize_message(msg) -> dict:
"""Serialize a LangChain message to a plain dict for values events."""
if isinstance(msg, AIMessage):
d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)}
if msg.tool_calls:
d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls]
return d
if isinstance(msg, ToolMessage):
return {
"type": "tool",
"content": msg.content if isinstance(msg.content, str) else str(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None),
"id": getattr(msg, "id", None),
}
if isinstance(msg, HumanMessage):
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
if isinstance(msg, SystemMessage):
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
@staticmethod @staticmethod
def _extract_text(content) -> str: def _extract_text(content) -> str:
"""Extract plain text from AIMessage content (str or list of blocks).""" """Extract plain text from AIMessage content (str or list of blocks)."""
@@ -246,6 +273,10 @@ class DeerFlowClient:
finishes its turn. A ``checkpointer`` must be provided at init time finishes its turn. A ``checkpointer`` must be provided at init time
for multi-turn context to be preserved across calls. for multi-turn context to be preserved across calls.
Event types align with the LangGraph SSE protocol so that
consumers can switch between HTTP streaming and embedded mode
without changing their event-handling logic.
Args: Args:
message: User message text. message: User message text.
thread_id: Thread ID for conversation context. Auto-generated if None. thread_id: Thread ID for conversation context. Auto-generated if None.
@@ -254,11 +285,11 @@ class DeerFlowClient:
Yields: Yields:
StreamEvent with one of: StreamEvent with one of:
- type="message" data={"content": str} - type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
- type="tool_call" data={"name": str, "args": dict, "id": str} - type="messages-tuple" data={"type": "ai", "content": str, "id": str}
- type="tool_result" data={"name": str, "content": str, "tool_call_id": str} - type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
- type="title" data={"title": str} - type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
- type="done" data={} - type="end" data={}
""" """
if thread_id is None: if thread_id is None:
thread_id = str(uuid.uuid4()) thread_id = str(uuid.uuid4())
@@ -270,7 +301,6 @@ class DeerFlowClient:
context = {"thread_id": thread_id} context = {"thread_id": thread_id}
seen_ids: set[str] = set() seen_ids: set[str] = set()
last_title: str | None = None
for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"): for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"):
messages = chunk.get("messages", []) messages = chunk.get("messages", [])
@@ -284,41 +314,57 @@ class DeerFlowClient:
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
if msg.tool_calls: if msg.tool_calls:
for tc in msg.tool_calls:
yield StreamEvent( yield StreamEvent(
type="tool_call", type="messages-tuple",
data={"name": tc["name"], "args": tc["args"], "id": tc.get("id")}, data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": [
{"name": tc["name"], "args": tc["args"], "id": tc.get("id")}
for tc in msg.tool_calls
],
},
) )
text = self._extract_text(msg.content) text = self._extract_text(msg.content)
if text: if text:
yield StreamEvent(type="message", data={"content": text}) yield StreamEvent(
type="messages-tuple",
data={"type": "ai", "content": text, "id": msg_id},
)
elif isinstance(msg, ToolMessage): elif isinstance(msg, ToolMessage):
yield StreamEvent( yield StreamEvent(
type="tool_result", type="messages-tuple",
data={ data={
"name": getattr(msg, "name", None), "type": "tool",
"content": msg.content if isinstance(msg.content, str) else str(msg.content), "content": msg.content if isinstance(msg.content, str) else str(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None), "tool_call_id": getattr(msg, "tool_call_id", None),
"id": msg_id,
}, },
) )
# Title changes # Emit a values event for each state snapshot
title = chunk.get("title") yield StreamEvent(
if title and title != last_title: type="values",
last_title = title data={
yield StreamEvent(type="title", data={"title": title}) "title": chunk.get("title"),
"messages": [self._serialize_message(m) for m in messages],
"artifacts": chunk.get("artifacts", []),
},
)
yield StreamEvent(type="done", data={}) yield StreamEvent(type="end", data={})
def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str: def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str:
"""Send a message and return the final text response. """Send a message and return the final text response.
Convenience wrapper around :meth:`stream` that returns only the Convenience wrapper around :meth:`stream` that returns only the
**last** ``message`` event's text. If the agent emits multiple **last** AI text from ``messages-tuple`` events. If the agent emits
message segments in one turn, intermediate segments are discarded. multiple text segments in one turn, intermediate segments are
Use :meth:`stream` directly to capture all events. discarded. Use :meth:`stream` directly to capture all events.
Args: Args:
message: User message text. message: User message text.
@@ -330,42 +376,59 @@ class DeerFlowClient:
""" """
last_text = "" last_text = ""
for event in self.stream(message, thread_id=thread_id, **kwargs): for event in self.stream(message, thread_id=thread_id, **kwargs):
if event.type == "message": if event.type == "messages-tuple" and event.data.get("type") == "ai":
last_text = event.data.get("content", "") content = event.data.get("content", "")
if content:
last_text = content
return last_text return last_text
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public API — configuration queries # Public API — configuration queries
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def list_models(self) -> list[dict]: def list_models(self) -> dict:
"""List available models from configuration. """List available models from configuration.
Returns: Returns:
List of model config dicts. Dict with "models" key containing list of model info dicts,
matching the Gateway API ``ModelsListResponse`` schema.
""" """
return [model.model_dump() for model in self._app_config.models] return {
"models": [
{
"name": model.name,
"display_name": getattr(model, "display_name", None),
"description": getattr(model, "description", None),
"supports_thinking": getattr(model, "supports_thinking", False),
}
for model in self._app_config.models
]
}
def list_skills(self, enabled_only: bool = False) -> list[dict]: def list_skills(self, enabled_only: bool = False) -> dict:
"""List available skills. """List available skills.
Args: Args:
enabled_only: If True, only return enabled skills. enabled_only: If True, only return enabled skills.
Returns: Returns:
List of skill info dicts with name, description, category, enabled. Dict with "skills" key containing list of skill info dicts,
matching the Gateway API ``SkillsListResponse`` schema.
""" """
from src.skills.loader import load_skills from src.skills.loader import load_skills
return [ return {
"skills": [
{ {
"name": s.name, "name": s.name,
"description": s.description, "description": s.description,
"license": s.license,
"category": s.category, "category": s.category,
"enabled": s.enabled, "enabled": s.enabled,
} }
for s in load_skills(enabled_only=enabled_only) for s in load_skills(enabled_only=enabled_only)
] ]
}
def get_memory(self) -> dict: def get_memory(self) -> dict:
"""Get current memory data. """Get current memory data.
@@ -384,25 +447,34 @@ class DeerFlowClient:
name: Model name. name: Model name.
Returns: Returns:
Model config dict, or None if not found. Model info dict matching the Gateway API ``ModelResponse``
schema, or None if not found.
""" """
model = self._app_config.get_model_config(name) model = self._app_config.get_model_config(name)
return model.model_dump() if model is not None else None if model is None:
return None
return {
"name": model.name,
"display_name": getattr(model, "display_name", None),
"description": getattr(model, "description", None),
"supports_thinking": getattr(model, "supports_thinking", False),
}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public API — MCP configuration # Public API — MCP configuration
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def get_mcp_config(self) -> dict[str, dict]: def get_mcp_config(self) -> dict:
"""Get MCP server configurations. """Get MCP server configurations.
Returns: Returns:
Dict mapping server name to its config dict. Dict with "mcp_servers" key mapping server name to config,
matching the Gateway API ``McpConfigResponse`` schema.
""" """
config = get_extensions_config() config = get_extensions_config()
return {name: server.model_dump() for name, server in config.mcp_servers.items()} return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}}
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict[str, dict]: def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
"""Update MCP server configurations. """Update MCP server configurations.
Writes to extensions_config.json and reloads the cache. Writes to extensions_config.json and reloads the cache.
@@ -412,7 +484,8 @@ class DeerFlowClient:
Each value should contain keys like enabled, type, command, args, env, url, etc. Each value should contain keys like enabled, type, command, args, env, url, etc.
Returns: Returns:
The updated MCP config. Dict with "mcp_servers" key, matching the Gateway API
``McpConfigResponse`` schema.
Raises: Raises:
OSError: If the config file cannot be written. OSError: If the config file cannot be written.
@@ -421,7 +494,7 @@ class DeerFlowClient:
if config_path is None: if config_path is None:
raise FileNotFoundError( raise FileNotFoundError(
"Cannot locate extensions_config.json. " "Cannot locate extensions_config.json. "
"Pass config_path to DeerFlowClient or set DEER_FLOW_HOME." "Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root."
) )
current_config = get_extensions_config() current_config = get_extensions_config()
@@ -435,7 +508,7 @@ class DeerFlowClient:
self._agent = None self._agent = None
reloaded = reload_extensions_config() reloaded = reload_extensions_config()
return {name: server.model_dump() for name, server in reloaded.mcp_servers.items()} return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public API — skills management # Public API — skills management
@@ -488,7 +561,7 @@ class DeerFlowClient:
if config_path is None: if config_path is None:
raise FileNotFoundError( raise FileNotFoundError(
"Cannot locate extensions_config.json. " "Cannot locate extensions_config.json. "
"Pass config_path to DeerFlowClient or set DEER_FLOW_HOME." "Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root."
) )
extensions_config = get_extensions_config() extensions_config = get_extensions_config()
@@ -634,7 +707,7 @@ class DeerFlowClient:
base.mkdir(parents=True, exist_ok=True) base.mkdir(parents=True, exist_ok=True)
return base return base
def upload_files(self, thread_id: str, files: list[str | Path]) -> list[dict]: def upload_files(self, thread_id: str, files: list[str | Path]) -> dict:
"""Upload local files into a thread's uploads directory. """Upload local files into a thread's uploads directory.
For PDF, PPT, Excel, and Word files, they are also converted to Markdown. For PDF, PPT, Excel, and Word files, they are also converted to Markdown.
@@ -644,7 +717,8 @@ class DeerFlowClient:
files: List of local file paths to upload. files: List of local file paths to upload.
Returns: Returns:
List of file info dicts (filename, size, path, virtual_path). Dict with success, files, message — matching the Gateway API
``UploadResponse`` schema.
Raises: Raises:
FileNotFoundError: If any file does not exist. FileNotFoundError: If any file does not exist.
@@ -660,7 +734,7 @@ class DeerFlowClient:
resolved_files.append(p) resolved_files.append(p)
uploads_dir = self._get_uploads_dir(thread_id) uploads_dir = self._get_uploads_dir(thread_id)
results: list[dict] = [] uploaded_files: list[dict] = []
for src_path in resolved_files: for src_path in resolved_files:
@@ -669,9 +743,10 @@ class DeerFlowClient:
info: dict[str, Any] = { info: dict[str, Any] = {
"filename": src_path.name, "filename": src_path.name,
"size": dest.stat().st_size, "size": str(dest.stat().st_size),
"path": str(dest), "path": str(dest),
"virtual_path": f"/mnt/user-data/uploads/{src_path.name}", "virtual_path": f"/mnt/user-data/uploads/{src_path.name}",
"artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{src_path.name}",
} }
if src_path.suffix.lower() in CONVERTIBLE_EXTENSIONS: if src_path.suffix.lower() in CONVERTIBLE_EXTENSIONS:
@@ -690,23 +765,29 @@ class DeerFlowClient:
if md_path is not None: if md_path is not None:
info["markdown_file"] = md_path.name info["markdown_file"] = md_path.name
info["markdown_virtual_path"] = f"/mnt/user-data/uploads/{md_path.name}" info["markdown_virtual_path"] = f"/mnt/user-data/uploads/{md_path.name}"
info["markdown_artifact_url"] = f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{md_path.name}"
results.append(info) uploaded_files.append(info)
return results return {
"success": True,
"files": uploaded_files,
"message": f"Successfully uploaded {len(uploaded_files)} file(s)",
}
def list_uploads(self, thread_id: str) -> list[dict]: def list_uploads(self, thread_id: str) -> dict:
"""List files in a thread's uploads directory. """List files in a thread's uploads directory.
Args: Args:
thread_id: Thread ID. thread_id: Thread ID.
Returns: Returns:
List of file info dicts. Dict with "files" and "count" keys, matching the Gateway API
``list_uploaded_files`` response.
""" """
uploads_dir = self._get_uploads_dir(thread_id) uploads_dir = self._get_uploads_dir(thread_id)
if not uploads_dir.exists(): if not uploads_dir.exists():
return [] return {"files": [], "count": 0}
files = [] files = []
for fp in sorted(uploads_dir.iterdir()): for fp in sorted(uploads_dir.iterdir()):
@@ -714,21 +795,26 @@ class DeerFlowClient:
stat = fp.stat() stat = fp.stat()
files.append({ files.append({
"filename": fp.name, "filename": fp.name,
"size": stat.st_size, "size": str(stat.st_size),
"path": str(fp), "path": str(fp),
"virtual_path": f"/mnt/user-data/uploads/{fp.name}", "virtual_path": f"/mnt/user-data/uploads/{fp.name}",
"artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{fp.name}",
"extension": fp.suffix, "extension": fp.suffix,
"modified": stat.st_mtime, "modified": stat.st_mtime,
}) })
return files return {"files": files, "count": len(files)}
def delete_upload(self, thread_id: str, filename: str) -> None: def delete_upload(self, thread_id: str, filename: str) -> dict:
"""Delete a file from a thread's uploads directory. """Delete a file from a thread's uploads directory.
Args: Args:
thread_id: Thread ID. thread_id: Thread ID.
filename: Filename to delete. filename: Filename to delete.
Returns:
Dict with success and message, matching the Gateway API
``delete_uploaded_file`` response.
Raises: Raises:
FileNotFoundError: If the file does not exist. FileNotFoundError: If the file does not exist.
PermissionError: If path traversal is detected. PermissionError: If path traversal is detected.
@@ -738,13 +824,14 @@ class DeerFlowClient:
try: try:
file_path.relative_to(uploads_dir.resolve()) file_path.relative_to(uploads_dir.resolve())
except ValueError: except ValueError as exc:
raise PermissionError("Access denied: path traversal detected") raise PermissionError("Access denied: path traversal detected") from exc
if not file_path.is_file(): if not file_path.is_file():
raise FileNotFoundError(f"File not found: {filename}") raise FileNotFoundError(f"File not found: {filename}")
file_path.unlink() file_path.unlink()
return {"success": True, "message": f"Deleted {filename}"}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public API — artifacts # Public API — artifacts
@@ -775,8 +862,8 @@ class DeerFlowClient:
try: try:
actual.relative_to(base_dir.resolve()) actual.relative_to(base_dir.resolve())
except ValueError: except ValueError as exc:
raise PermissionError("Access denied: path traversal detected") raise PermissionError("Access denied: path traversal detected") from exc
if not actual.exists(): if not actual.exists():
raise FileNotFoundError(f"Artifact not found: {path}") raise FileNotFoundError(f"Artifact not found: {path}")
if not actual.is_file(): if not actual.is_file():

View File

@@ -10,6 +10,11 @@ import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage # noqa: F401 from langchain_core.messages import AIMessage, HumanMessage, ToolMessage # noqa: F401
from src.client import DeerFlowClient from src.client import DeerFlowClient
from src.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
from src.gateway.routers.mcp import McpConfigResponse
from src.gateway.routers.models import ModelResponse, ModelsListResponse
from src.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
from src.gateway.routers.uploads import UploadResponse
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Fixtures # Fixtures
@@ -81,14 +86,19 @@ class TestClientInit:
class TestConfigQueries: class TestConfigQueries:
def test_list_models(self, client): def test_list_models(self, client):
models = client.list_models() result = client.list_models()
assert len(models) == 1 assert "models" in result
assert models[0]["name"] == "test-model" assert len(result["models"]) == 1
assert result["models"][0]["name"] == "test-model"
# Verify Gateway-aligned fields are present
assert "display_name" in result["models"][0]
assert "supports_thinking" in result["models"][0]
def test_list_skills(self, client): def test_list_skills(self, client):
skill = MagicMock() skill = MagicMock()
skill.name = "web-search" skill.name = "web-search"
skill.description = "Search the web" skill.description = "Search the web"
skill.license = "MIT"
skill.category = "public" skill.category = "public"
skill.enabled = True skill.enabled = True
@@ -96,10 +106,12 @@ class TestConfigQueries:
result = client.list_skills() result = client.list_skills()
mock_load.assert_called_once_with(enabled_only=False) mock_load.assert_called_once_with(enabled_only=False)
assert len(result) == 1 assert "skills" in result
assert result[0] == { assert len(result["skills"]) == 1
assert result["skills"][0] == {
"name": "web-search", "name": "web-search",
"description": "Search the web", "description": "Search the web",
"license": "MIT",
"category": "public", "category": "public",
"enabled": True, "enabled": True,
} }
@@ -128,9 +140,24 @@ def _make_agent_mock(chunks: list[dict]):
return agent return agent
def _ai_events(events):
"""Filter messages-tuple events with type=ai and non-empty content."""
return [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")]
def _tool_call_events(events):
"""Filter messages-tuple events with type=ai and tool_calls."""
return [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data]
def _tool_result_events(events):
"""Filter messages-tuple events with type=tool."""
return [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "tool"]
class TestStream: class TestStream:
def test_basic_message(self, client): def test_basic_message(self, client):
"""stream() emits message + done for a simple AI reply.""" """stream() emits messages-tuple + values + end for a simple AI reply."""
ai = AIMessage(content="Hello!", id="ai-1") ai = AIMessage(content="Hello!", id="ai-1")
chunks = [ chunks = [
{"messages": [HumanMessage(content="hi", id="h-1")]}, {"messages": [HumanMessage(content="hi", id="h-1")]},
@@ -145,13 +172,14 @@ class TestStream:
events = list(client.stream("hi", thread_id="t1")) events = list(client.stream("hi", thread_id="t1"))
types = [e.type for e in events] types = [e.type for e in events]
assert "message" in types assert "messages-tuple" in types
assert types[-1] == "done" assert "values" in types
msg_events = [e for e in events if e.type == "message"] assert types[-1] == "end"
msg_events = _ai_events(events)
assert msg_events[0].data["content"] == "Hello!" assert msg_events[0].data["content"] == "Hello!"
def test_tool_call_and_result(self, client): def test_tool_call_and_result(self, client):
"""stream() emits tool_call and tool_result events.""" """stream() emits messages-tuple events for tool calls and results."""
ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}]) ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}])
tool = ToolMessage(content="file.txt", id="tm-1", tool_call_id="tc-1", name="bash") tool = ToolMessage(content="file.txt", id="tm-1", tool_call_id="tc-1", name="bash")
ai2 = AIMessage(content="Here are the files.", id="ai-2") ai2 = AIMessage(content="Here are the files.", id="ai-2")
@@ -169,14 +197,13 @@ class TestStream:
): ):
events = list(client.stream("list files", thread_id="t2")) events = list(client.stream("list files", thread_id="t2"))
types = [e.type for e in events] assert len(_tool_call_events(events)) >= 1
assert "tool_call" in types assert len(_tool_result_events(events)) >= 1
assert "tool_result" in types assert len(_ai_events(events)) >= 1
assert "message" in types assert events[-1].type == "end"
assert types[-1] == "done"
def test_title_event(self, client): def test_values_event_with_title(self, client):
"""stream() emits title event when title appears in state.""" """stream() emits values event containing title when present in state."""
ai = AIMessage(content="ok", id="ai-1") ai = AIMessage(content="ok", id="ai-1")
chunks = [ chunks = [
{"messages": [HumanMessage(content="hi", id="h-1"), ai], "title": "Greeting"}, {"messages": [HumanMessage(content="hi", id="h-1"), ai], "title": "Greeting"},
@@ -189,9 +216,10 @@ class TestStream:
): ):
events = list(client.stream("hi", thread_id="t3")) events = list(client.stream("hi", thread_id="t3"))
title_events = [e for e in events if e.type == "title"] values_events = [e for e in events if e.type == "values"]
assert len(title_events) == 1 assert len(values_events) >= 1
assert title_events[0].data["title"] == "Greeting" assert values_events[-1].data["title"] == "Greeting"
assert "messages" in values_events[-1].data
def test_deduplication(self, client): def test_deduplication(self, client):
"""Messages with the same id are not emitted twice.""" """Messages with the same id are not emitted twice."""
@@ -208,7 +236,7 @@ class TestStream:
): ):
events = list(client.stream("hi", thread_id="t4")) events = list(client.stream("hi", thread_id="t4"))
msg_events = [e for e in events if e.type == "message"] msg_events = _ai_events(events)
assert len(msg_events) == 1 assert len(msg_events) == 1
def test_auto_thread_id(self, client): def test_auto_thread_id(self, client):
@@ -221,8 +249,8 @@ class TestStream:
): ):
events = list(client.stream("hi")) events = list(client.stream("hi"))
# Should not raise; done event proves it completed # Should not raise; end event proves it completed
assert events[-1].type == "done" assert events[-1].type == "end"
def test_list_content_blocks(self, client): def test_list_content_blocks(self, client):
"""stream() handles AIMessage with list-of-blocks content.""" """stream() handles AIMessage with list-of-blocks content."""
@@ -242,7 +270,7 @@ class TestStream:
): ):
events = list(client.stream("hi", thread_id="t5")) events = list(client.stream("hi", thread_id="t5"))
msg_events = [e for e in events if e.type == "message"] msg_events = _ai_events(events)
assert len(msg_events) == 1 assert len(msg_events) == 1
assert msg_events[0].data["content"] == "result" assert msg_events[0].data["content"] == "result"
@@ -347,11 +375,19 @@ class TestEnsureAgent:
class TestGetModel: class TestGetModel:
def test_found(self, client): def test_found(self, client):
model_cfg = MagicMock() model_cfg = MagicMock()
model_cfg.model_dump.return_value = {"name": "test-model"} model_cfg.name = "test-model"
model_cfg.display_name = "Test Model"
model_cfg.description = "A test model"
model_cfg.supports_thinking = True
client._app_config.get_model_config.return_value = model_cfg client._app_config.get_model_config.return_value = model_cfg
result = client.get_model("test-model") result = client.get_model("test-model")
assert result == {"name": "test-model"} assert result == {
"name": "test-model",
"display_name": "Test Model",
"description": "A test model",
"supports_thinking": True,
}
def test_not_found(self, client): def test_not_found(self, client):
client._app_config.get_model_config.return_value = None client._app_config.get_model_config.return_value = None
@@ -372,8 +408,9 @@ class TestMcpConfig:
with patch("src.client.get_extensions_config", return_value=ext_config): with patch("src.client.get_extensions_config", return_value=ext_config):
result = client.get_mcp_config() result = client.get_mcp_config()
assert "github" in result assert "mcp_servers" in result
assert result["github"]["enabled"] is True assert "github" in result["mcp_servers"]
assert result["mcp_servers"]["github"]["enabled"] is True
def test_update_mcp_config(self, client): def test_update_mcp_config(self, client):
# Set up current config with skills # Set up current config with skills
@@ -400,7 +437,8 @@ class TestMcpConfig:
): ):
result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}}) result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}})
assert "new-server" in result assert "mcp_servers" in result
assert "new-server" in result["mcp_servers"]
assert client._agent is None # M2: agent invalidated assert client._agent is None # M2: agent invalidated
# Verify file was actually written # Verify file was actually written
@@ -578,8 +616,11 @@ class TestUploads:
with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
result = client.upload_files("thread-1", [src_file]) result = client.upload_files("thread-1", [src_file])
assert len(result) == 1 assert result["success"] is True
assert result[0]["filename"] == "test.txt" assert len(result["files"]) == 1
assert result["files"][0]["filename"] == "test.txt"
assert "artifact_url" in result["files"][0]
assert "message" in result
assert (uploads_dir / "test.txt").exists() assert (uploads_dir / "test.txt").exists()
def test_upload_files_not_found(self, client): def test_upload_files_not_found(self, client):
@@ -595,9 +636,13 @@ class TestUploads:
with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
result = client.list_uploads("thread-1") result = client.list_uploads("thread-1")
assert len(result) == 2 assert result["count"] == 2
names = {f["filename"] for f in result} assert len(result["files"]) == 2
names = {f["filename"] for f in result["files"]}
assert names == {"a.txt", "b.txt"} assert names == {"a.txt", "b.txt"}
# Verify artifact_url is present
for f in result["files"]:
assert "artifact_url" in f
def test_delete_upload(self, client): def test_delete_upload(self, client):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
@@ -605,8 +650,10 @@ class TestUploads:
(uploads_dir / "delete-me.txt").write_text("gone") (uploads_dir / "delete-me.txt").write_text("gone")
with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
client.delete_upload("thread-1", "delete-me.txt") result = client.delete_upload("thread-1", "delete-me.txt")
assert result["success"] is True
assert "delete-me.txt" in result["message"]
assert not (uploads_dir / "delete-me.txt").exists() assert not (uploads_dir / "delete-me.txt").exists()
def test_delete_upload_not_found(self, client): def test_delete_upload_not_found(self, client):
@@ -707,7 +754,7 @@ class TestScenarioMultiTurnConversation:
assert agent.stream.call_count == 2 assert agent.stream.call_count == 2
def test_stream_collects_all_event_types_across_turns(self, client): def test_stream_collects_all_event_types_across_turns(self, client):
"""A full turn with tool_call tool_result → message → title → done.""" """A full turn emits messages-tuple (tool_call, tool_result, ai text) + values + end."""
ai_tc = AIMessage(content="", id="ai-1", tool_calls=[ ai_tc = AIMessage(content="", id="ai-1", tool_calls=[
{"name": "web_search", "args": {"query": "LangGraph"}, "id": "tc-1"}, {"name": "web_search", "args": {"query": "LangGraph"}, "id": "tc-1"},
]) ])
@@ -727,23 +774,30 @@ class TestScenarioMultiTurnConversation:
): ):
events = list(client.stream("search", thread_id="t-full")) events = list(client.stream("search", thread_id="t-full"))
types = [e.type for e in events] # Verify expected event types
assert types == ["tool_call", "tool_result", "message", "title", "done"] types = set(e.type for e in events)
assert types == {"messages-tuple", "values", "end"}
assert events[-1].type == "end"
# Verify event data integrity # Verify tool_call data
tc_event = events[0] tc_events = _tool_call_events(events)
assert tc_event.data["name"] == "web_search" assert len(tc_events) == 1
assert tc_event.data["args"] == {"query": "LangGraph"} assert tc_events[0].data["tool_calls"][0]["name"] == "web_search"
assert tc_events[0].data["tool_calls"][0]["args"] == {"query": "LangGraph"}
tr_event = events[1] # Verify tool_result data
assert tr_event.data["tool_call_id"] == "tc-1" tr_events = _tool_result_events(events)
assert "LangGraph" in tr_event.data["content"] assert len(tr_events) == 1
assert tr_events[0].data["tool_call_id"] == "tc-1"
assert "LangGraph" in tr_events[0].data["content"]
msg_event = events[2] # Verify AI text
assert "framework" in msg_event.data["content"] msg_events = _ai_events(events)
assert any("framework" in e.data["content"] for e in msg_events)
title_event = events[3] # Verify values event contains title
assert title_event.data["title"] == "LangGraph Search" values_events = [e for e in events if e.type == "values"]
assert any(e.data.get("title") == "LangGraph Search" for e in values_events)
class TestScenarioToolChain: class TestScenarioToolChain:
@@ -776,16 +830,16 @@ class TestScenarioToolChain:
): ):
events = list(client.stream("list and save", thread_id="t-chain")) events = list(client.stream("list and save", thread_id="t-chain"))
tool_calls = [e for e in events if e.type == "tool_call"] tool_calls = _tool_call_events(events)
tool_results = [e for e in events if e.type == "tool_result"] tool_results = _tool_result_events(events)
messages = [e for e in events if e.type == "message"] messages = _ai_events(events)
assert len(tool_calls) == 2 assert len(tool_calls) == 2
assert tool_calls[0].data["name"] == "bash" assert tool_calls[0].data["tool_calls"][0]["name"] == "bash"
assert tool_calls[1].data["name"] == "write_file" assert tool_calls[1].data["tool_calls"][0]["name"] == "write_file"
assert len(tool_results) == 2 assert len(tool_results) == 2
assert len(messages) == 1 assert len(messages) == 1
assert events[-1].type == "done" assert events[-1].type == "end"
class TestScenarioFileLifecycle: class TestScenarioFileLifecycle:
@@ -804,25 +858,27 @@ class TestScenarioFileLifecycle:
with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
# Step 1: Upload # Step 1: Upload
uploaded = client.upload_files("t-lifecycle", [ result = client.upload_files("t-lifecycle", [
tmp_path / "report.txt", tmp_path / "report.txt",
tmp_path / "data.csv", tmp_path / "data.csv",
]) ])
assert len(uploaded) == 2 assert result["success"] is True
assert {f["filename"] for f in uploaded} == {"report.txt", "data.csv"} assert len(result["files"]) == 2
assert {f["filename"] for f in result["files"]} == {"report.txt", "data.csv"}
# Step 2: List # Step 2: List
files = client.list_uploads("t-lifecycle") listed = client.list_uploads("t-lifecycle")
assert len(files) == 2 assert listed["count"] == 2
assert all("virtual_path" in f for f in files) assert all("virtual_path" in f for f in listed["files"])
# Step 3: Delete one # Step 3: Delete one
client.delete_upload("t-lifecycle", "report.txt") del_result = client.delete_upload("t-lifecycle", "report.txt")
assert del_result["success"] is True
# Step 4: Verify deletion # Step 4: Verify deletion
files = client.list_uploads("t-lifecycle") listed = client.list_uploads("t-lifecycle")
assert len(files) == 1 assert listed["count"] == 1
assert files[0]["filename"] == "data.csv" assert listed["files"][0]["filename"] == "data.csv"
def test_upload_then_read_artifact(self, client): def test_upload_then_read_artifact(self, client):
"""Upload a file, simulate agent producing artifact, read it back.""" """Upload a file, simulate agent producing artifact, read it back."""
@@ -840,7 +896,7 @@ class TestScenarioFileLifecycle:
with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
uploaded = client.upload_files("t-artifact", [src_file]) uploaded = client.upload_files("t-artifact", [src_file])
assert len(uploaded) == 1 assert len(uploaded["files"]) == 1
# Simulate agent writing an artifact # Simulate agent writing an artifact
(outputs_dir / "analysis.json").write_text('{"result": "processed"}') (outputs_dir / "analysis.json").write_text('{"result": "processed"}')
@@ -862,13 +918,16 @@ class TestScenarioConfigManagement:
def test_model_and_skill_discovery(self, client): def test_model_and_skill_discovery(self, client):
"""List models → get specific model → list skills → get specific skill.""" """List models → get specific model → list skills → get specific skill."""
# List models # List models
models = client.list_models() result = client.list_models()
assert len(models) >= 1 assert len(result["models"]) >= 1
model_name = models[0]["name"] model_name = result["models"][0]["name"]
# Get specific model # Get specific model
model_cfg = MagicMock() model_cfg = MagicMock()
model_cfg.model_dump.return_value = {"name": model_name, "use": "langchain_openai:ChatOpenAI"} model_cfg.name = model_name
model_cfg.display_name = None
model_cfg.description = None
model_cfg.supports_thinking = False
client._app_config.get_model_config.return_value = model_cfg client._app_config.get_model_config.return_value = model_cfg
detail = client.get_model(model_name) detail = client.get_model(model_name)
assert detail["name"] == model_name assert detail["name"] == model_name
@@ -877,12 +936,13 @@ class TestScenarioConfigManagement:
skill = MagicMock() skill = MagicMock()
skill.name = "web-search" skill.name = "web-search"
skill.description = "Search the web" skill.description = "Search the web"
skill.license = "MIT"
skill.category = "public" skill.category = "public"
skill.enabled = True skill.enabled = True
with patch("src.skills.loader.load_skills", return_value=[skill]): with patch("src.skills.loader.load_skills", return_value=[skill]):
skills = client.list_skills() skills_result = client.list_skills()
assert len(skills) == 1 assert len(skills_result["skills"]) == 1
# Get specific skill # Get specific skill
with patch("src.skills.loader.load_skills", return_value=[skill]): with patch("src.skills.loader.load_skills", return_value=[skill]):
@@ -912,7 +972,7 @@ class TestScenarioConfigManagement:
patch("src.client.reload_extensions_config", return_value=reloaded_config), patch("src.client.reload_extensions_config", return_value=reloaded_config),
): ):
mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}}) mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}})
assert "my-mcp" in mcp_result assert "my-mcp" in mcp_result["mcp_servers"]
assert client._agent is None # Agent invalidated assert client._agent is None # Agent invalidated
# --- Skill toggle --- # --- Skill toggle ---
@@ -1072,8 +1132,8 @@ class TestScenarioThreadIsolation:
files_a = client.list_uploads("thread-a") files_a = client.list_uploads("thread-a")
files_b = client.list_uploads("thread-b") files_b = client.list_uploads("thread-b")
assert len(files_a) == 1 assert files_a["count"] == 1
assert len(files_b) == 0 assert files_b["count"] == 0
def test_artifacts_isolated_per_thread(self, client): def test_artifacts_isolated_per_thread(self, client):
"""Artifacts in thread-A are not accessible from thread-B.""" """Artifacts in thread-A are not accessible from thread-B."""
@@ -1168,12 +1228,13 @@ class TestScenarioSkillInstallAndUse:
installed_skill = MagicMock() installed_skill = MagicMock()
installed_skill.name = "my-analyzer" installed_skill.name = "my-analyzer"
installed_skill.description = "Analyze code" installed_skill.description = "Analyze code"
installed_skill.license = "MIT"
installed_skill.category = "custom" installed_skill.category = "custom"
installed_skill.enabled = True installed_skill.enabled = True
with patch("src.skills.loader.load_skills", return_value=[installed_skill]): with patch("src.skills.loader.load_skills", return_value=[installed_skill]):
skills = client.list_skills() skills_result = client.list_skills()
assert any(s["name"] == "my-analyzer" for s in skills) assert any(s["name"] == "my-analyzer" for s in skills_result["skills"])
# Step 3: Disable it # Step 3: Disable it
disabled_skill = MagicMock() disabled_skill = MagicMock()
@@ -1204,7 +1265,7 @@ class TestScenarioEdgeCases:
"""Scenario: Edge cases and error boundaries in realistic workflows.""" """Scenario: Edge cases and error boundaries in realistic workflows."""
def test_empty_stream_response(self, client): def test_empty_stream_response(self, client):
"""Agent produces no messages — only done event.""" """Agent produces no messages — only values + end events."""
agent = _make_agent_mock([{"messages": []}]) agent = _make_agent_mock([{"messages": []}])
with ( with (
@@ -1213,8 +1274,10 @@ class TestScenarioEdgeCases:
): ):
events = list(client.stream("hi", thread_id="t-empty")) events = list(client.stream("hi", thread_id="t-empty"))
assert len(events) == 1 # values event (empty messages) + end
assert events[0].type == "done" assert len(events) == 2
assert events[0].type == "values"
assert events[-1].type == "end"
def test_chat_on_empty_response(self, client): def test_chat_on_empty_response(self, client):
"""chat() returns empty string for no-message response.""" """chat() returns empty string for no-message response."""
@@ -1229,12 +1292,12 @@ class TestScenarioEdgeCases:
assert result == "" assert result == ""
def test_multiple_title_changes(self, client): def test_multiple_title_changes(self, client):
"""Only distinct title changes produce events.""" """Title changes are carried in values events."""
ai = AIMessage(content="ok", id="ai-1") ai = AIMessage(content="ok", id="ai-1")
chunks = [ chunks = [
{"messages": [ai], "title": "First Title"}, {"messages": [ai], "title": "First Title"},
{"messages": [], "title": "First Title"}, # same — should NOT emit {"messages": [], "title": "First Title"}, # same title repeated
{"messages": [], "title": "Second Title"}, # different — should emit {"messages": [], "title": "Second Title"}, # different title
] ]
agent = _make_agent_mock(chunks) agent = _make_agent_mock(chunks)
@@ -1244,13 +1307,15 @@ class TestScenarioEdgeCases:
): ):
events = list(client.stream("hi", thread_id="t-titles")) events = list(client.stream("hi", thread_id="t-titles"))
title_events = [e for e in events if e.type == "title"] # Every chunk produces a values event with the title
assert len(title_events) == 2 values_events = [e for e in events if e.type == "values"]
assert title_events[0].data["title"] == "First Title" assert len(values_events) == 3
assert title_events[1].data["title"] == "Second Title" assert values_events[0].data["title"] == "First Title"
assert values_events[1].data["title"] == "First Title"
assert values_events[2].data["title"] == "Second Title"
def test_concurrent_tool_calls_in_single_message(self, client): def test_concurrent_tool_calls_in_single_message(self, client):
"""Agent produces multiple tool_calls in one AIMessage.""" """Agent produces multiple tool_calls in one AIMessage — emitted as single messages-tuple."""
ai = AIMessage(content="", id="ai-1", tool_calls=[ ai = AIMessage(content="", id="ai-1", tool_calls=[
{"name": "web_search", "args": {"q": "a"}, "id": "tc-1"}, {"name": "web_search", "args": {"q": "a"}, "id": "tc-1"},
{"name": "web_search", "args": {"q": "b"}, "id": "tc-2"}, {"name": "web_search", "args": {"q": "b"}, "id": "tc-2"},
@@ -1265,9 +1330,11 @@ class TestScenarioEdgeCases:
): ):
events = list(client.stream("do things", thread_id="t-parallel")) events = list(client.stream("do things", thread_id="t-parallel"))
tc_events = [e for e in events if e.type == "tool_call"] tc_events = _tool_call_events(events)
assert len(tc_events) == 3 assert len(tc_events) == 1 # One messages-tuple event for the AIMessage
assert {e.data["id"] for e in tc_events} == {"tc-1", "tc-2", "tc-3"} tool_calls = tc_events[0].data["tool_calls"]
assert len(tool_calls) == 3
assert {tc["id"] for tc in tool_calls} == {"tc-1", "tc-2", "tc-3"}
def test_upload_convertible_file_conversion_failure(self, client): def test_upload_convertible_file_conversion_failure(self, client):
"""Upload a .pdf file where conversion fails — file still uploaded, no markdown.""" """Upload a .pdf file where conversion fails — file still uploaded, no markdown."""
@@ -1284,9 +1351,223 @@ class TestScenarioEdgeCases:
patch("src.gateway.routers.uploads.CONVERTIBLE_EXTENSIONS", {".pdf"}), patch("src.gateway.routers.uploads.CONVERTIBLE_EXTENSIONS", {".pdf"}),
patch("src.gateway.routers.uploads.convert_file_to_markdown", side_effect=Exception("conversion failed")), patch("src.gateway.routers.uploads.convert_file_to_markdown", side_effect=Exception("conversion failed")),
): ):
results = client.upload_files("t-pdf-fail", [pdf_file]) result = client.upload_files("t-pdf-fail", [pdf_file])
assert len(results) == 1 assert result["success"] is True
assert results[0]["filename"] == "doc.pdf" assert len(result["files"]) == 1
assert "markdown_file" not in results[0] # Conversion failed gracefully assert result["files"][0]["filename"] == "doc.pdf"
assert "markdown_file" not in result["files"][0] # Conversion failed gracefully
assert (uploads_dir / "doc.pdf").exists() # File still uploaded assert (uploads_dir / "doc.pdf").exists() # File still uploaded
# ---------------------------------------------------------------------------
# Gateway conformance — validate client output against Gateway Pydantic models
# ---------------------------------------------------------------------------


class TestGatewayConformance:
    """Validate that DeerFlowClient return dicts conform to Gateway Pydantic response models.

    Each test calls a client method, then parses the result through the
    corresponding Gateway response model. If the client drifts (missing or
    wrong-typed fields), Pydantic raises ``ValidationError`` and CI catches it.
    """

    def test_list_models(self, mock_app_config):
        # Minimal model stub carrying exactly the attributes ModelResponse exposes.
        model = MagicMock()
        model.name = "test-model"
        model.display_name = "Test Model"
        model.description = "A test model"
        model.supports_thinking = False
        mock_app_config.models = [model]
        # Patch at the client's import site so the freshly-built client sees the stub config.
        with patch("src.client.get_app_config", return_value=mock_app_config):
            client = DeerFlowClient()
            result = client.list_models()
        # ValidationError here means the client dict drifted from the Gateway schema.
        parsed = ModelsListResponse(**result)
        assert len(parsed.models) == 1
        assert parsed.models[0].name == "test-model"

    def test_get_model(self, mock_app_config):
        model = MagicMock()
        model.name = "test-model"
        model.display_name = "Test Model"
        model.description = "A test model"
        model.supports_thinking = True
        mock_app_config.models = [model]
        # get_model() resolves via get_model_config; return the same stub.
        mock_app_config.get_model_config.return_value = model
        with patch("src.client.get_app_config", return_value=mock_app_config):
            client = DeerFlowClient()
            result = client.get_model("test-model")
        assert result is not None
        parsed = ModelResponse(**result)
        assert parsed.name == "test-model"

    def test_list_skills(self, client):
        # Stub skill with all fields SkillResponse requires (license included).
        skill = MagicMock()
        skill.name = "web-search"
        skill.description = "Search the web"
        skill.license = "MIT"
        skill.category = "public"
        skill.enabled = True
        with patch("src.skills.loader.load_skills", return_value=[skill]):
            result = client.list_skills()
        parsed = SkillsListResponse(**result)
        assert len(parsed.skills) == 1
        assert parsed.skills[0].name == "web-search"

    def test_get_skill(self, client):
        skill = MagicMock()
        skill.name = "web-search"
        skill.description = "Search the web"
        skill.license = "MIT"
        skill.category = "public"
        skill.enabled = True
        with patch("src.skills.loader.load_skills", return_value=[skill]):
            result = client.get_skill("web-search")
        assert result is not None
        parsed = SkillResponse(**result)
        assert parsed.name == "web-search"

    def test_install_skill(self, client, tmp_path):
        # Build a minimal .skill archive: one SKILL.md with YAML front matter.
        skill_dir = tmp_path / "my-skill"
        skill_dir.mkdir()
        (skill_dir / "SKILL.md").write_text(
            "---\nname: my-skill\ndescription: A test skill\n---\nBody\n"
        )
        archive = tmp_path / "my-skill.skill"
        with zipfile.ZipFile(archive, "w") as zf:
            zf.write(skill_dir / "SKILL.md", "my-skill/SKILL.md")
        # Pre-create the "custom" category directory under the skills root
        # (presumably install_skill extracts custom skills into it — TODO confirm).
        custom_dir = tmp_path / "custom"
        custom_dir.mkdir()
        with patch("src.skills.loader.get_skills_root_path", return_value=tmp_path):
            result = client.install_skill(archive)
        parsed = SkillInstallResponse(**result)
        assert parsed.success is True
        assert parsed.skill_name == "my-skill"

    def test_get_mcp_config(self, client):
        # model_dump() must yield every field McpConfigResponse expects per server.
        server = MagicMock()
        server.model_dump.return_value = {
            "enabled": True,
            "type": "stdio",
            "command": "npx",
            "args": ["-y", "server"],
            "env": {},
            "url": None,
            "headers": {},
            "description": "test server",
        }
        ext_config = MagicMock()
        ext_config.mcp_servers = {"test": server}
        with patch("src.client.get_extensions_config", return_value=ext_config):
            result = client.get_mcp_config()
        parsed = McpConfigResponse(**result)
        assert "test" in parsed.mcp_servers

    def test_update_mcp_config(self, client, tmp_path):
        server = MagicMock()
        server.model_dump.return_value = {
            "enabled": True,
            "type": "stdio",
            "command": "npx",
            "args": [],
            "env": {},
            "url": None,
            "headers": {},
            "description": "",
        }
        ext_config = MagicMock()
        ext_config.mcp_servers = {"srv": server}
        ext_config.skills = {}
        # Real file on disk: update_mcp_config writes the merged config back.
        config_file = tmp_path / "extensions_config.json"
        config_file.write_text("{}")
        with (
            patch("src.client.get_extensions_config", return_value=ext_config),
            patch("src.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
            patch("src.client.reload_extensions_config", return_value=ext_config),
        ):
            result = client.update_mcp_config({"srv": server.model_dump.return_value})
        parsed = McpConfigResponse(**result)
        assert "srv" in parsed.mcp_servers

    def test_upload_files(self, client, tmp_path):
        uploads_dir = tmp_path / "uploads"
        uploads_dir.mkdir()
        src_file = tmp_path / "hello.txt"
        src_file.write_text("hello")
        with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
            result = client.upload_files("t-conform", [src_file])
        parsed = UploadResponse(**result)
        assert parsed.success is True
        assert len(parsed.files) == 1

    def test_get_memory_config(self, client):
        # Stub config with every field MemoryConfigResponse declares.
        mem_cfg = MagicMock()
        mem_cfg.enabled = True
        mem_cfg.storage_path = ".deer-flow/memory.json"
        mem_cfg.debounce_seconds = 30
        mem_cfg.max_facts = 100
        mem_cfg.fact_confidence_threshold = 0.7
        mem_cfg.injection_enabled = True
        mem_cfg.max_injection_tokens = 2000
        with patch("src.config.memory_config.get_memory_config", return_value=mem_cfg):
            result = client.get_memory_config()
        parsed = MemoryConfigResponse(**result)
        assert parsed.enabled is True
        assert parsed.max_facts == 100

    def test_get_memory_status(self, client):
        mem_cfg = MagicMock()
        mem_cfg.enabled = True
        mem_cfg.storage_path = ".deer-flow/memory.json"
        mem_cfg.debounce_seconds = 30
        mem_cfg.max_facts = 100
        mem_cfg.fact_confidence_threshold = 0.7
        mem_cfg.injection_enabled = True
        mem_cfg.max_injection_tokens = 2000
        # Full empty-but-valid memory document matching the stored JSON layout.
        memory_data = {
            "version": "1.0",
            "lastUpdated": "",
            "user": {
                "workContext": {"summary": "", "updatedAt": ""},
                "personalContext": {"summary": "", "updatedAt": ""},
                "topOfMind": {"summary": "", "updatedAt": ""},
            },
            "history": {
                "recentMonths": {"summary": "", "updatedAt": ""},
                "earlierContext": {"summary": "", "updatedAt": ""},
                "longTermBackground": {"summary": "", "updatedAt": ""},
            },
            "facts": [],
        }
        with (
            patch("src.config.memory_config.get_memory_config", return_value=mem_cfg),
            patch("src.agents.memory.updater.get_memory_data", return_value=memory_data),
        ):
            result = client.get_memory_status()
        # MemoryStatusResponse nests config + data models; both must validate.
        parsed = MemoryStatusResponse(**result)
        assert parsed.config.enabled is True
        assert parsed.data.version == "1.0"

View File

@@ -8,7 +8,6 @@ They are skipped in CI and must be run explicitly:
import json import json
import os import os
import tempfile
from pathlib import Path from pathlib import Path
import pytest import pytest
@@ -68,26 +67,27 @@ class TestLiveBasicChat:
# =========================================================================== # ===========================================================================
class TestLiveStreaming: class TestLiveStreaming:
def test_stream_yields_message_and_done(self, client): def test_stream_yields_messages_tuple_and_end(self, client):
"""stream() produces at least one message event and ends with done.""" """stream() produces at least one messages-tuple event and ends with end."""
events = list(client.stream("Say hi in one word.")) events = list(client.stream("Say hi in one word."))
types = [e.type for e in events] types = [e.type for e in events]
assert "message" in types, f"Expected 'message' event, got: {types}" assert "messages-tuple" in types, f"Expected 'messages-tuple' event, got: {types}"
assert types[-1] == "done" assert "values" in types, f"Expected 'values' event, got: {types}"
assert types[-1] == "end"
for e in events: for e in events:
assert isinstance(e, StreamEvent) assert isinstance(e, StreamEvent)
print(f" [{e.type}] {e.data}") print(f" [{e.type}] {e.data}")
def test_stream_message_content_nonempty(self, client): def test_stream_ai_content_nonempty(self, client):
"""Streamed message events contain non-empty content.""" """Streamed messages-tuple AI events contain non-empty content."""
messages = [ ai_messages = [
e for e in client.stream("What color is the sky? One word.") e for e in client.stream("What color is the sky? One word.")
if e.type == "message" if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")
] ]
assert len(messages) >= 1 assert len(ai_messages) >= 1
for m in messages: for m in ai_messages:
assert len(m.data.get("content", "")) > 0 assert len(m.data.get("content", "")) > 0
@@ -108,16 +108,18 @@ class TestLiveToolUse:
for e in events: for e in events:
print(f" [{e.type}] {e.data}") print(f" [{e.type}] {e.data}")
# Should have tool_call + tool_result + message # All message events are now messages-tuple
assert "tool_call" in types, f"Expected tool_call, got: {types}" mt_events = [e for e in events if e.type == "messages-tuple"]
assert "tool_result" in types, f"Expected tool_result, got: {types}" tc_events = [e for e in mt_events if e.data.get("type") == "ai" and "tool_calls" in e.data]
assert "message" in types tr_events = [e for e in mt_events if e.data.get("type") == "tool"]
ai_events = [e for e in mt_events if e.data.get("type") == "ai" and e.data.get("content")]
tc = next(e for e in events if e.type == "tool_call") assert len(tc_events) >= 1, f"Expected tool_call event, got types: {types}"
assert tc.data["name"] == "bash" assert len(tr_events) >= 1, f"Expected tool result event, got types: {types}"
assert len(ai_events) >= 1
tr = next(e for e in events if e.type == "tool_result") assert tc_events[0].data["tool_calls"][0]["name"] == "bash"
assert "LIVE_TEST_OK" in tr.data["content"] assert "LIVE_TEST_OK" in tr_events[0].data["content"]
def test_agent_uses_ls_tool(self, client): def test_agent_uses_ls_tool(self, client):
"""Agent uses ls tool to list a directory.""" """Agent uses ls tool to list a directory."""
@@ -129,9 +131,9 @@ class TestLiveToolUse:
types = [e.type for e in events] types = [e.type for e in events]
print(f" event types: {types}") print(f" event types: {types}")
assert "tool_call" in types tc_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data]
tc = next(e for e in events if e.type == "tool_call") assert len(tc_events) >= 1
assert tc.data["name"] == "ls" assert tc_events[0].data["tool_calls"][0]["name"] == "ls"
# =========================================================================== # ===========================================================================
@@ -153,18 +155,19 @@ class TestLiveMultiToolChain:
for e in events: for e in events:
print(f" [{e.type}] {e.data}") print(f" [{e.type}] {e.data}")
tool_calls = [e for e in events if e.type == "tool_call"] tc_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data]
tool_names = [tc.data["name"] for tc in tool_calls] tool_names = [tc.data["tool_calls"][0]["name"] for tc in tc_events]
assert "write_file" in tool_names, f"Expected write_file, got: {tool_names}" assert "write_file" in tool_names, f"Expected write_file, got: {tool_names}"
assert "read_file" in tool_names, f"Expected read_file, got: {tool_names}" assert "read_file" in tool_names, f"Expected read_file, got: {tool_names}"
# Final message should mention the content # Final AI message or tool result should mention the content
messages = [e for e in events if e.type == "message"] ai_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")]
final_text = messages[-1].data["content"] if messages else "" tr_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "tool"]
final_text = ai_events[-1].data["content"] if ai_events else ""
assert "integration_test_content" in final_text.lower() or any( assert "integration_test_content" in final_text.lower() or any(
"integration_test_content" in e.data.get("content", "") "integration_test_content" in e.data.get("content", "")
for e in events if e.type == "tool_result" for e in tr_events
) )
@@ -184,30 +187,35 @@ class TestLiveFileUpload:
f2.write_text("content B") f2.write_text("content B")
# Upload # Upload
results = client.upload_files(thread_id, [f1, f2]) result = client.upload_files(thread_id, [f1, f2])
assert len(results) == 2 assert result["success"] is True
filenames = {r["filename"] for r in results} assert len(result["files"]) == 2
filenames = {r["filename"] for r in result["files"]}
assert filenames == {"test_upload_a.txt", "test_upload_b.txt"} assert filenames == {"test_upload_a.txt", "test_upload_b.txt"}
for r in results: for r in result["files"]:
assert r["size"] > 0 assert int(r["size"]) > 0
assert r["virtual_path"].startswith("/mnt/user-data/uploads/") assert r["virtual_path"].startswith("/mnt/user-data/uploads/")
assert "artifact_url" in r
print(f" uploaded: {filenames}") print(f" uploaded: {filenames}")
# List # List
listed = client.list_uploads(thread_id) listed = client.list_uploads(thread_id)
assert len(listed) == 2 assert listed["count"] == 2
print(f" listed: {[f['filename'] for f in listed]}") print(f" listed: {[f['filename'] for f in listed['files']]}")
# Delete one # Delete one
client.delete_upload(thread_id, "test_upload_a.txt") del_result = client.delete_upload(thread_id, "test_upload_a.txt")
assert del_result["success"] is True
remaining = client.list_uploads(thread_id) remaining = client.list_uploads(thread_id)
assert len(remaining) == 1 assert remaining["count"] == 1
assert remaining[0]["filename"] == "test_upload_b.txt" assert remaining["files"][0]["filename"] == "test_upload_b.txt"
print(f" after delete: {[f['filename'] for f in remaining]}") print(f" after delete: {[f['filename'] for f in remaining['files']]}")
# Delete the other # Delete the other
client.delete_upload(thread_id, "test_upload_b.txt") client.delete_upload(thread_id, "test_upload_b.txt")
assert client.list_uploads(thread_id) == [] empty = client.list_uploads(thread_id)
assert empty["count"] == 0
assert empty["files"] == []
def test_upload_nonexistent_file_raises(self, client): def test_upload_nonexistent_file_raises(self, client):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
@@ -221,10 +229,15 @@ class TestLiveFileUpload:
class TestLiveConfigQueries: class TestLiveConfigQueries:
def test_list_models_returns_ark(self, client): def test_list_models_returns_ark(self, client):
"""list_models() returns the configured ARK model.""" """list_models() returns the configured ARK model."""
models = client.list_models() result = client.list_models()
assert len(models) >= 1 assert "models" in result
names = [m["name"] for m in models] assert len(result["models"]) >= 1
names = [m["name"] for m in result["models"]]
assert "ark-model" in names assert "ark-model" in names
# Verify Gateway-aligned fields
for m in result["models"]:
assert "display_name" in m
assert "supports_thinking" in m
print(f" models: {names}") print(f" models: {names}")
def test_get_model_found(self, client): def test_get_model_found(self, client):
@@ -232,6 +245,8 @@ class TestLiveConfigQueries:
model = client.get_model("ark-model") model = client.get_model("ark-model")
assert model is not None assert model is not None
assert model["name"] == "ark-model" assert model["name"] == "ark-model"
assert "display_name" in model
assert "supports_thinking" in model
print(f" model detail: {model}") print(f" model detail: {model}")
def test_get_model_not_found(self, client): def test_get_model_not_found(self, client):
@@ -239,10 +254,11 @@ class TestLiveConfigQueries:
def test_list_skills(self, client): def test_list_skills(self, client):
"""list_skills() runs without error.""" """list_skills() runs without error."""
skills = client.list_skills() result = client.list_skills()
assert isinstance(skills, list) assert "skills" in result
print(f" skills count: {len(skills)}") assert isinstance(result["skills"], list)
for s in skills[:3]: print(f" skills count: {len(result['skills'])}")
for s in result["skills"][:3]:
print(f" - {s['name']}: {s['enabled']}") print(f" - {s['name']}: {s['enabled']}")
@@ -264,8 +280,11 @@ class TestLiveArtifact:
)) ))
# Verify write happened # Verify write happened
tool_calls = [e for e in events if e.type == "tool_call"] tc_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data]
assert any(tc.data["name"] == "write_file" for tc in tool_calls) assert any(
any(tc["name"] == "write_file" for tc in e.data["tool_calls"])
for e in tc_events
)
# Read artifact # Read artifact
content, mime = client.get_artifact(thread_id, "mnt/user-data/outputs/artifact_test.json") content, mime = client.get_artifact(thread_id, "mnt/user-data/outputs/artifact_test.json")