mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-18 20:14:44 +08:00
test: add Gateway conformance tests for DeerFlowClient (#931)
Validate that all dict-returning client methods conform to Gateway Pydantic response models (ModelsListResponse, ModelResponse, SkillsListResponse, SkillResponse, SkillInstallResponse, McpConfigResponse, UploadResponse, MemoryConfigResponse, MemoryStatusResponse). Pydantic ValidationError in CI catches schema drift between client and Gateway with zero production coupling. Also includes prior review fixes: enhanced client methods, expanded unit tests (67→77), live integration test improvements, and updated documentation. Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
@@ -30,7 +30,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.agents.lead_agent.agent import _build_middlewares
|
||||
@@ -38,8 +38,8 @@ from src.agents.lead_agent.prompt import apply_prompt_template
|
||||
from src.agents.thread_state import ThreadState
|
||||
from src.config.app_config import get_app_config, reload_app_config
|
||||
from src.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||
from src.models import create_chat_model
|
||||
from src.config.paths import get_paths
|
||||
from src.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,8 +48,13 @@ logger = logging.getLogger(__name__)
|
||||
class StreamEvent:
|
||||
"""A single event from the streaming agent response.
|
||||
|
||||
Event types align with the LangGraph SSE protocol:
|
||||
- ``"values"``: Full state snapshot (title, messages, artifacts).
|
||||
- ``"messages-tuple"``: Per-message update (AI text, tool calls, tool results).
|
||||
- ``"end"``: Stream finished.
|
||||
|
||||
Attributes:
|
||||
type: Event type — "message", "tool_call", "tool_result", "title", or "done".
|
||||
type: Event type.
|
||||
data: Event payload. Contents vary by type.
|
||||
"""
|
||||
|
||||
@@ -214,6 +219,28 @@ class DeerFlowClient:
|
||||
|
||||
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_message(msg) -> dict:
|
||||
"""Serialize a LangChain message to a plain dict for values events."""
|
||||
if isinstance(msg, AIMessage):
|
||||
d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if msg.tool_calls:
|
||||
d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls]
|
||||
return d
|
||||
if isinstance(msg, ToolMessage):
|
||||
return {
|
||||
"type": "tool",
|
||||
"content": msg.content if isinstance(msg.content, str) else str(msg.content),
|
||||
"name": getattr(msg, "name", None),
|
||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||
"id": getattr(msg, "id", None),
|
||||
}
|
||||
if isinstance(msg, HumanMessage):
|
||||
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if isinstance(msg, SystemMessage):
|
||||
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(content) -> str:
|
||||
"""Extract plain text from AIMessage content (str or list of blocks)."""
|
||||
@@ -246,6 +273,10 @@ class DeerFlowClient:
|
||||
finishes its turn. A ``checkpointer`` must be provided at init time
|
||||
for multi-turn context to be preserved across calls.
|
||||
|
||||
Event types align with the LangGraph SSE protocol so that
|
||||
consumers can switch between HTTP streaming and embedded mode
|
||||
without changing their event-handling logic.
|
||||
|
||||
Args:
|
||||
message: User message text.
|
||||
thread_id: Thread ID for conversation context. Auto-generated if None.
|
||||
@@ -254,11 +285,11 @@ class DeerFlowClient:
|
||||
|
||||
Yields:
|
||||
StreamEvent with one of:
|
||||
- type="message" data={"content": str}
|
||||
- type="tool_call" data={"name": str, "args": dict, "id": str}
|
||||
- type="tool_result" data={"name": str, "content": str, "tool_call_id": str}
|
||||
- type="title" data={"title": str}
|
||||
- type="done" data={}
|
||||
- type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
|
||||
- type="messages-tuple" data={"type": "ai", "content": str, "id": str}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
||||
- type="end" data={}
|
||||
"""
|
||||
if thread_id is None:
|
||||
thread_id = str(uuid.uuid4())
|
||||
@@ -270,7 +301,6 @@ class DeerFlowClient:
|
||||
context = {"thread_id": thread_id}
|
||||
|
||||
seen_ids: set[str] = set()
|
||||
last_title: str | None = None
|
||||
|
||||
for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"):
|
||||
messages = chunk.get("messages", [])
|
||||
@@ -284,41 +314,57 @@ class DeerFlowClient:
|
||||
|
||||
if isinstance(msg, AIMessage):
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
yield StreamEvent(
|
||||
type="tool_call",
|
||||
data={"name": tc["name"], "args": tc["args"], "id": tc.get("id")},
|
||||
)
|
||||
yield StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": [
|
||||
{"name": tc["name"], "args": tc["args"], "id": tc.get("id")}
|
||||
for tc in msg.tool_calls
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
text = self._extract_text(msg.content)
|
||||
if text:
|
||||
yield StreamEvent(type="message", data={"content": text})
|
||||
yield StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={"type": "ai", "content": text, "id": msg_id},
|
||||
)
|
||||
|
||||
elif isinstance(msg, ToolMessage):
|
||||
yield StreamEvent(
|
||||
type="tool_result",
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"name": getattr(msg, "name", None),
|
||||
"type": "tool",
|
||||
"content": msg.content if isinstance(msg.content, str) else str(msg.content),
|
||||
"name": getattr(msg, "name", None),
|
||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||
"id": msg_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Title changes
|
||||
title = chunk.get("title")
|
||||
if title and title != last_title:
|
||||
last_title = title
|
||||
yield StreamEvent(type="title", data={"title": title})
|
||||
# Emit a values event for each state snapshot
|
||||
yield StreamEvent(
|
||||
type="values",
|
||||
data={
|
||||
"title": chunk.get("title"),
|
||||
"messages": [self._serialize_message(m) for m in messages],
|
||||
"artifacts": chunk.get("artifacts", []),
|
||||
},
|
||||
)
|
||||
|
||||
yield StreamEvent(type="done", data={})
|
||||
yield StreamEvent(type="end", data={})
|
||||
|
||||
def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str:
|
||||
"""Send a message and return the final text response.
|
||||
|
||||
Convenience wrapper around :meth:`stream` that returns only the
|
||||
**last** ``message`` event's text. If the agent emits multiple
|
||||
message segments in one turn, intermediate segments are discarded.
|
||||
Use :meth:`stream` directly to capture all events.
|
||||
**last** AI text from ``messages-tuple`` events. If the agent emits
|
||||
multiple text segments in one turn, intermediate segments are
|
||||
discarded. Use :meth:`stream` directly to capture all events.
|
||||
|
||||
Args:
|
||||
message: User message text.
|
||||
@@ -330,42 +376,59 @@ class DeerFlowClient:
|
||||
"""
|
||||
last_text = ""
|
||||
for event in self.stream(message, thread_id=thread_id, **kwargs):
|
||||
if event.type == "message":
|
||||
last_text = event.data.get("content", "")
|
||||
if event.type == "messages-tuple" and event.data.get("type") == "ai":
|
||||
content = event.data.get("content", "")
|
||||
if content:
|
||||
last_text = content
|
||||
return last_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — configuration queries
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_models(self) -> list[dict]:
|
||||
def list_models(self) -> dict:
|
||||
"""List available models from configuration.
|
||||
|
||||
Returns:
|
||||
List of model config dicts.
|
||||
Dict with "models" key containing list of model info dicts,
|
||||
matching the Gateway API ``ModelsListResponse`` schema.
|
||||
"""
|
||||
return [model.model_dump() for model in self._app_config.models]
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
"name": model.name,
|
||||
"display_name": getattr(model, "display_name", None),
|
||||
"description": getattr(model, "description", None),
|
||||
"supports_thinking": getattr(model, "supports_thinking", False),
|
||||
}
|
||||
for model in self._app_config.models
|
||||
]
|
||||
}
|
||||
|
||||
def list_skills(self, enabled_only: bool = False) -> list[dict]:
|
||||
def list_skills(self, enabled_only: bool = False) -> dict:
|
||||
"""List available skills.
|
||||
|
||||
Args:
|
||||
enabled_only: If True, only return enabled skills.
|
||||
|
||||
Returns:
|
||||
List of skill info dicts with name, description, category, enabled.
|
||||
Dict with "skills" key containing list of skill info dicts,
|
||||
matching the Gateway API ``SkillsListResponse`` schema.
|
||||
"""
|
||||
from src.skills.loader import load_skills
|
||||
|
||||
return [
|
||||
{
|
||||
"name": s.name,
|
||||
"description": s.description,
|
||||
"category": s.category,
|
||||
"enabled": s.enabled,
|
||||
}
|
||||
for s in load_skills(enabled_only=enabled_only)
|
||||
]
|
||||
return {
|
||||
"skills": [
|
||||
{
|
||||
"name": s.name,
|
||||
"description": s.description,
|
||||
"license": s.license,
|
||||
"category": s.category,
|
||||
"enabled": s.enabled,
|
||||
}
|
||||
for s in load_skills(enabled_only=enabled_only)
|
||||
]
|
||||
}
|
||||
|
||||
def get_memory(self) -> dict:
|
||||
"""Get current memory data.
|
||||
@@ -384,25 +447,34 @@ class DeerFlowClient:
|
||||
name: Model name.
|
||||
|
||||
Returns:
|
||||
Model config dict, or None if not found.
|
||||
Model info dict matching the Gateway API ``ModelResponse``
|
||||
schema, or None if not found.
|
||||
"""
|
||||
model = self._app_config.get_model_config(name)
|
||||
return model.model_dump() if model is not None else None
|
||||
if model is None:
|
||||
return None
|
||||
return {
|
||||
"name": model.name,
|
||||
"display_name": getattr(model, "display_name", None),
|
||||
"description": getattr(model, "description", None),
|
||||
"supports_thinking": getattr(model, "supports_thinking", False),
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — MCP configuration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_mcp_config(self) -> dict[str, dict]:
|
||||
def get_mcp_config(self) -> dict:
|
||||
"""Get MCP server configurations.
|
||||
|
||||
Returns:
|
||||
Dict mapping server name to its config dict.
|
||||
Dict with "mcp_servers" key mapping server name to config,
|
||||
matching the Gateway API ``McpConfigResponse`` schema.
|
||||
"""
|
||||
config = get_extensions_config()
|
||||
return {name: server.model_dump() for name, server in config.mcp_servers.items()}
|
||||
return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}}
|
||||
|
||||
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict[str, dict]:
|
||||
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
|
||||
"""Update MCP server configurations.
|
||||
|
||||
Writes to extensions_config.json and reloads the cache.
|
||||
@@ -412,7 +484,8 @@ class DeerFlowClient:
|
||||
Each value should contain keys like enabled, type, command, args, env, url, etc.
|
||||
|
||||
Returns:
|
||||
The updated MCP config.
|
||||
Dict with "mcp_servers" key, matching the Gateway API
|
||||
``McpConfigResponse`` schema.
|
||||
|
||||
Raises:
|
||||
OSError: If the config file cannot be written.
|
||||
@@ -421,7 +494,7 @@ class DeerFlowClient:
|
||||
if config_path is None:
|
||||
raise FileNotFoundError(
|
||||
"Cannot locate extensions_config.json. "
|
||||
"Pass config_path to DeerFlowClient or set DEER_FLOW_HOME."
|
||||
"Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root."
|
||||
)
|
||||
|
||||
current_config = get_extensions_config()
|
||||
@@ -435,7 +508,7 @@ class DeerFlowClient:
|
||||
|
||||
self._agent = None
|
||||
reloaded = reload_extensions_config()
|
||||
return {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}
|
||||
return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — skills management
|
||||
@@ -488,7 +561,7 @@ class DeerFlowClient:
|
||||
if config_path is None:
|
||||
raise FileNotFoundError(
|
||||
"Cannot locate extensions_config.json. "
|
||||
"Pass config_path to DeerFlowClient or set DEER_FLOW_HOME."
|
||||
"Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root."
|
||||
)
|
||||
|
||||
extensions_config = get_extensions_config()
|
||||
@@ -634,7 +707,7 @@ class DeerFlowClient:
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
return base
|
||||
|
||||
def upload_files(self, thread_id: str, files: list[str | Path]) -> list[dict]:
|
||||
def upload_files(self, thread_id: str, files: list[str | Path]) -> dict:
|
||||
"""Upload local files into a thread's uploads directory.
|
||||
|
||||
For PDF, PPT, Excel, and Word files, they are also converted to Markdown.
|
||||
@@ -644,7 +717,8 @@ class DeerFlowClient:
|
||||
files: List of local file paths to upload.
|
||||
|
||||
Returns:
|
||||
List of file info dicts (filename, size, path, virtual_path).
|
||||
Dict with success, files, message — matching the Gateway API
|
||||
``UploadResponse`` schema.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If any file does not exist.
|
||||
@@ -660,7 +734,7 @@ class DeerFlowClient:
|
||||
resolved_files.append(p)
|
||||
|
||||
uploads_dir = self._get_uploads_dir(thread_id)
|
||||
results: list[dict] = []
|
||||
uploaded_files: list[dict] = []
|
||||
|
||||
for src_path in resolved_files:
|
||||
|
||||
@@ -669,9 +743,10 @@ class DeerFlowClient:
|
||||
|
||||
info: dict[str, Any] = {
|
||||
"filename": src_path.name,
|
||||
"size": dest.stat().st_size,
|
||||
"size": str(dest.stat().st_size),
|
||||
"path": str(dest),
|
||||
"virtual_path": f"/mnt/user-data/uploads/{src_path.name}",
|
||||
"artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{src_path.name}",
|
||||
}
|
||||
|
||||
if src_path.suffix.lower() in CONVERTIBLE_EXTENSIONS:
|
||||
@@ -690,23 +765,29 @@ class DeerFlowClient:
|
||||
if md_path is not None:
|
||||
info["markdown_file"] = md_path.name
|
||||
info["markdown_virtual_path"] = f"/mnt/user-data/uploads/{md_path.name}"
|
||||
info["markdown_artifact_url"] = f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{md_path.name}"
|
||||
|
||||
results.append(info)
|
||||
uploaded_files.append(info)
|
||||
|
||||
return results
|
||||
return {
|
||||
"success": True,
|
||||
"files": uploaded_files,
|
||||
"message": f"Successfully uploaded {len(uploaded_files)} file(s)",
|
||||
}
|
||||
|
||||
def list_uploads(self, thread_id: str) -> list[dict]:
|
||||
def list_uploads(self, thread_id: str) -> dict:
|
||||
"""List files in a thread's uploads directory.
|
||||
|
||||
Args:
|
||||
thread_id: Thread ID.
|
||||
|
||||
Returns:
|
||||
List of file info dicts.
|
||||
Dict with "files" and "count" keys, matching the Gateway API
|
||||
``list_uploaded_files`` response.
|
||||
"""
|
||||
uploads_dir = self._get_uploads_dir(thread_id)
|
||||
if not uploads_dir.exists():
|
||||
return []
|
||||
return {"files": [], "count": 0}
|
||||
|
||||
files = []
|
||||
for fp in sorted(uploads_dir.iterdir()):
|
||||
@@ -714,21 +795,26 @@ class DeerFlowClient:
|
||||
stat = fp.stat()
|
||||
files.append({
|
||||
"filename": fp.name,
|
||||
"size": stat.st_size,
|
||||
"size": str(stat.st_size),
|
||||
"path": str(fp),
|
||||
"virtual_path": f"/mnt/user-data/uploads/{fp.name}",
|
||||
"artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{fp.name}",
|
||||
"extension": fp.suffix,
|
||||
"modified": stat.st_mtime,
|
||||
})
|
||||
return files
|
||||
return {"files": files, "count": len(files)}
|
||||
|
||||
def delete_upload(self, thread_id: str, filename: str) -> None:
|
||||
def delete_upload(self, thread_id: str, filename: str) -> dict:
|
||||
"""Delete a file from a thread's uploads directory.
|
||||
|
||||
Args:
|
||||
thread_id: Thread ID.
|
||||
filename: Filename to delete.
|
||||
|
||||
Returns:
|
||||
Dict with success and message, matching the Gateway API
|
||||
``delete_uploaded_file`` response.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file does not exist.
|
||||
PermissionError: If path traversal is detected.
|
||||
@@ -738,13 +824,14 @@ class DeerFlowClient:
|
||||
|
||||
try:
|
||||
file_path.relative_to(uploads_dir.resolve())
|
||||
except ValueError:
|
||||
raise PermissionError("Access denied: path traversal detected")
|
||||
except ValueError as exc:
|
||||
raise PermissionError("Access denied: path traversal detected") from exc
|
||||
|
||||
if not file_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {filename}")
|
||||
|
||||
file_path.unlink()
|
||||
return {"success": True, "message": f"Deleted {filename}"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — artifacts
|
||||
@@ -775,8 +862,8 @@ class DeerFlowClient:
|
||||
|
||||
try:
|
||||
actual.relative_to(base_dir.resolve())
|
||||
except ValueError:
|
||||
raise PermissionError("Access denied: path traversal detected")
|
||||
except ValueError as exc:
|
||||
raise PermissionError("Access denied: path traversal detected") from exc
|
||||
if not actual.exists():
|
||||
raise FileNotFoundError(f"Artifact not found: {path}")
|
||||
if not actual.is_file():
|
||||
|
||||
Reference in New Issue
Block a user