feat: add DeerFlowClient for embedded programmatic access (#926)

Add `DeerFlowClient` class that provides direct in-process access to
DeerFlow's agent and Gateway capabilities without requiring LangGraph
Server or Gateway API processes. This enables users to import and use
DeerFlow as a Python library.

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
greatmengqi
2026-02-28 14:38:15 +08:00
committed by GitHub
parent 5ad8a657f4
commit 9d48c42a20
7 changed files with 2450 additions and 2 deletions

1
.gitignore vendored

@@ -36,6 +36,7 @@ coverage/
.claude/
skills/custom/*
logs/
log/
# Local git hooks (keep only on this machine, do not push)
.githooks/


@@ -236,6 +236,31 @@ DeerFlow is model-agnostic — it works with any LLM that implements the OpenAI-
- **Multimodal inputs** for image understanding and video comprehension
- **Strong tool-use** for reliable function calling and structured outputs
## Embedded Python Client

DeerFlow can be used as an embedded Python library without running the full HTTP services. The `DeerFlowClient` provides direct in-process access to all agent and Gateway capabilities:

```python
from src.client import DeerFlowClient

client = DeerFlowClient()

# Chat
response = client.chat("Analyze this paper for me", thread_id="my-thread")

# Streaming
for event in client.stream("hello"):
    print(event.type, event.data)

# Configuration & management
print(client.list_models())
print(client.list_skills())
client.update_skill("web-search", enabled=True)
client.upload_files("thread-1", ["./report.pdf"])
```

See `backend/src/client.py` for full API documentation.

## Documentation
- [Contributing Guide](CONTRIBUTING.md) - Development environment setup and workflow

3
backend/.gitignore vendored

@@ -11,6 +11,9 @@ wheels/
agent_history.gif
static/browser_history/*.gif
log/
log/*
# Virtual environments
.venv
venv/


@@ -47,7 +47,8 @@ deer-flow/
│ │ ├── config/ # Configuration system (app, model, sandbox, tool, etc.)
│ │ ├── community/ # Community tools (tavily, jina_ai, firecrawl, image_search, aio_sandbox)
│ │ ├── reflection/ # Dynamic module loading (resolve_variable, resolve_class)
│ │ ├── utils/ # Utilities (network, readability)
│ │ └── client.py # Embedded Python client (DeerFlowClient)
│ ├── tests/ # Test suite
│ └── docs/ # Documentation
├── frontend/ # Next.js frontend application
@@ -289,7 +290,35 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
- `mcpServers` - Map of server name → config (enabled, type, command, args, env, url, headers, description)
- `skills` - Map of skill name → state (enabled)
Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` methods.

### Embedded Client (`src/client.py`)

`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services.

**Architecture**: Imports the same `src/` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency.

**Agent Conversation** (replaces LangGraph Server):
- `chat(message, thread_id)`: synchronous, returns final text
- `stream(message, thread_id)`: yields `StreamEvent` (message, tool_call, tool_result, title, done)
- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent`
- Supports `checkpointer` parameter for state persistence across turns
- Invocation pattern: `agent.stream(state, config, context, stream_mode="values")`

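With `stream_mode="values"`, each chunk carries the full message list accumulated so far, so a consumer has to skip messages it has already handled. The sketch below illustrates that de-duplication by message id, using plain dicts in place of LangChain message objects. `iter_new_messages` is a hypothetical helper name used only for illustration.

```python
from collections.abc import Iterable, Iterator
from typing import Any


def iter_new_messages(chunks: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]:
    """Yield each message once, even though every chunk repeats earlier ones.

    Assumption: each chunk is a dict with a "messages" list, and each message
    is a dict with an optional "id" key (stand-ins for real message objects).
    """
    seen_ids: set[str] = set()
    for chunk in chunks:
        for msg in chunk.get("messages", []):
            msg_id = msg.get("id")
            if msg_id and msg_id in seen_ids:
                continue  # already emitted from an earlier chunk
            if msg_id:
                seen_ids.add(msg_id)
            yield msg
```

Messages without an id cannot be tracked, so in this sketch they are yielded every time they appear.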

**Gateway Equivalent Methods** (replaces Gateway API):

| Category | Methods |
|----------|---------|
| Models | `list_models()`, `get_model(name)` |
| MCP | `get_mcp_config()`, `update_mcp_config(servers)` |
| Skills | `list_skills()`, `get_skill(name)`, `update_skill(name, enabled)`, `install_skill(path)` |
| Memory | `get_memory()`, `reload_memory()`, `get_memory_config()`, `get_memory_status()` |
| Uploads | `upload_files(thread_id, files)`, `list_uploads(thread_id)`, `delete_upload(thread_id, filename)` |
| Artifacts | `get_artifact(thread_id, path)` → `(bytes, mime_type)` |

**Key difference from Gateway**: Upload accepts local `Path` objects instead of HTTP `UploadFile`. Artifact returns `(bytes, mime_type)` instead of HTTP Response.

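Because the artifact comes back as raw bytes plus a MIME type rather than an HTTP response, the caller decides what to do with it. A minimal sketch of one option, writing the bytes to disk with a suffix derived from the MIME type, is shown here. `save_artifact` is a hypothetical helper, not part of the client, and the `.bin` fallback is an arbitrary choice.

```python
import mimetypes
import tempfile
from pathlib import Path


def save_artifact(data: bytes, mime_type: str, dest_stem: Path) -> Path:
    """Write artifact bytes next to dest_stem, choosing a suffix from mime_type.

    dest_stem is a path without an extension, e.g. Path("out/report").
    Any existing suffix on dest_stem is replaced.
    """
    suffix = mimetypes.guess_extension(mime_type) or ".bin"
    target = dest_stem.with_suffix(suffix)
    target.write_bytes(data)
    return target


if __name__ == "__main__":
    # Stand-in for: data, mime = client.get_artifact(thread_id, virtual_path)
    data, mime = b"%PDF-1.4", "application/pdf"
    with tempfile.TemporaryDirectory() as tmp:
        print(save_artifact(data, mime, Path(tmp) / "report"))
```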
**Tests**: `tests/test_client.py` (45 unit tests)
## Development Workflow

786
backend/src/client.py Normal file

@@ -0,0 +1,786 @@
"""DeerFlowClient: embedded Python client for the DeerFlow agent system.

Provides direct programmatic access to DeerFlow's agent capabilities
without requiring LangGraph Server or Gateway API processes.

Usage:
    from src.client import DeerFlowClient

    client = DeerFlowClient()
    response = client.chat("Analyze this paper for me", thread_id="my-thread")
    print(response)

    # Streaming
    for event in client.stream("hello"):
        print(event)
"""
import asyncio
import json
import logging
import mimetypes
import re
import shutil
import tempfile
import uuid
import zipfile
from collections.abc import Generator
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from src.agents.lead_agent.agent import _build_middlewares
from src.agents.lead_agent.prompt import apply_prompt_template
from src.agents.thread_state import ThreadState
from src.config.app_config import get_app_config, reload_app_config
from src.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from src.models import create_chat_model
from src.config.paths import get_paths
logger = logging.getLogger(__name__)
@dataclass
class StreamEvent:
    """A single event from the streaming agent response.

    Attributes:
        type: Event type, one of "message", "tool_call", "tool_result", "title", or "done".
        data: Event payload. Contents vary by type.
    """

    type: str
    data: dict[str, Any] = field(default_factory=dict)
class DeerFlowClient:
"""Embedded Python client for DeerFlow agent system.
Provides direct programmatic access to DeerFlow's agent capabilities
without requiring LangGraph Server or Gateway API processes.
Note:
Multi-turn conversations require a ``checkpointer``. Without one,
each ``stream()`` / ``chat()`` call is stateless — ``thread_id``
is only used for file isolation (uploads / artifacts).
The system prompt (including date, memory, and skills context) is
generated when the internal agent is first created and cached until
the configuration key changes. Call :meth:`reset_agent` to force
a refresh in long-running processes.
Example::
from src.client import DeerFlowClient
client = DeerFlowClient()
# Simple one-shot
print(client.chat("hello"))
# Streaming
for event in client.stream("hello"):
print(event.type, event.data)
# Configuration queries
print(client.list_models())
print(client.list_skills())
"""
def __init__(
self,
config_path: str | None = None,
checkpointer=None,
*,
model_name: str | None = None,
thinking_enabled: bool = True,
subagent_enabled: bool = False,
plan_mode: bool = False,
):
"""Initialize the client.
Loads configuration but defers agent creation to first use.
Args:
config_path: Path to config.yaml. Uses default resolution if None.
checkpointer: LangGraph checkpointer instance for state persistence.
Required for multi-turn conversations on the same thread_id.
Without a checkpointer, each call is stateless.
model_name: Override the default model name from config.
thinking_enabled: Enable model's extended thinking.
subagent_enabled: Enable subagent delegation.
plan_mode: Enable TodoList middleware for plan mode.
"""
if config_path is not None:
reload_app_config(config_path)
self._app_config = get_app_config()
self._checkpointer = checkpointer
self._model_name = model_name
self._thinking_enabled = thinking_enabled
self._subagent_enabled = subagent_enabled
self._plan_mode = plan_mode
# Lazy agent — created on first call, recreated when config changes.
self._agent = None
self._agent_config_key: tuple | None = None
def reset_agent(self) -> None:
"""Force the internal agent to be recreated on the next call.
Use this after external changes (e.g. memory updates, skill
installations) that should be reflected in the system prompt
or tool set.
"""
self._agent = None
self._agent_config_key = None
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
    @staticmethod
    def _atomic_write_json(path: Path, data: dict) -> None:
        """Write JSON to *path* atomically (temp file + replace)."""
        # Create the temp file in the target directory so the final
        # replace() stays on one filesystem and is atomic.
        fd = tempfile.NamedTemporaryFile(
            mode="w", dir=path.parent, suffix=".tmp", delete=False,
        )
        try:
            json.dump(data, fd, indent=2)
            fd.close()
            Path(fd.name).replace(path)
        except BaseException:
            # Clean up the temp file on any failure, then re-raise.
            fd.close()
            Path(fd.name).unlink(missing_ok=True)
            raise
def _get_runnable_config(self, thread_id: str, **overrides) -> RunnableConfig:
"""Build a RunnableConfig for agent invocation."""
configurable = {
"thread_id": thread_id,
"model_name": overrides.get("model_name", self._model_name),
"thinking_enabled": overrides.get("thinking_enabled", self._thinking_enabled),
"is_plan_mode": overrides.get("plan_mode", self._plan_mode),
"subagent_enabled": overrides.get("subagent_enabled", self._subagent_enabled),
}
return RunnableConfig(
configurable=configurable,
recursion_limit=overrides.get("recursion_limit", 100),
)
def _ensure_agent(self, config: RunnableConfig):
"""Create (or recreate) the agent when config-dependent params change."""
cfg = config.get("configurable", {})
key = (
cfg.get("model_name"),
cfg.get("thinking_enabled"),
cfg.get("is_plan_mode"),
cfg.get("subagent_enabled"),
)
if self._agent is not None and self._agent_config_key == key:
return
thinking_enabled = cfg.get("thinking_enabled", True)
model_name = cfg.get("model_name")
subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name),
"system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
),
"state_schema": ThreadState,
}
if self._checkpointer is not None:
kwargs["checkpointer"] = self._checkpointer
self._agent = create_agent(**kwargs)
self._agent_config_key = key
logger.info("Agent created: model=%s, thinking=%s", model_name, thinking_enabled)
@staticmethod
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
"""Lazy import to avoid circular dependency at module level."""
from src.tools import get_available_tools
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
@staticmethod
def _extract_text(content) -> str:
"""Extract plain text from AIMessage content (str or list of blocks)."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict) and block.get("type") == "text":
parts.append(block["text"])
return "\n".join(parts) if parts else ""
return str(content)
# ------------------------------------------------------------------
# Public API — conversation
# ------------------------------------------------------------------
def stream(
self,
message: str,
*,
thread_id: str | None = None,
**kwargs,
) -> Generator[StreamEvent, None, None]:
"""Stream a conversation turn, yielding events incrementally.
Each call sends one user message and yields events until the agent
finishes its turn. A ``checkpointer`` must be provided at init time
for multi-turn context to be preserved across calls.
Args:
message: User message text.
thread_id: Thread ID for conversation context. Auto-generated if None.
**kwargs: Override client defaults (model_name, thinking_enabled,
plan_mode, subagent_enabled, recursion_limit).
Yields:
StreamEvent with one of:
- type="message" data={"content": str}
- type="tool_call" data={"name": str, "args": dict, "id": str}
- type="tool_result" data={"name": str, "content": str, "tool_call_id": str}
- type="title" data={"title": str}
- type="done" data={}
"""
if thread_id is None:
thread_id = str(uuid.uuid4())
config = self._get_runnable_config(thread_id, **kwargs)
self._ensure_agent(config)
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
context = {"thread_id": thread_id}
seen_ids: set[str] = set()
last_title: str | None = None
for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"):
messages = chunk.get("messages", [])
for msg in messages:
msg_id = getattr(msg, "id", None)
if msg_id and msg_id in seen_ids:
continue
if msg_id:
seen_ids.add(msg_id)
if isinstance(msg, AIMessage):
if msg.tool_calls:
for tc in msg.tool_calls:
yield StreamEvent(
type="tool_call",
data={"name": tc["name"], "args": tc["args"], "id": tc.get("id")},
)
text = self._extract_text(msg.content)
if text:
yield StreamEvent(type="message", data={"content": text})
elif isinstance(msg, ToolMessage):
yield StreamEvent(
type="tool_result",
data={
"name": getattr(msg, "name", None),
"content": msg.content if isinstance(msg.content, str) else str(msg.content),
"tool_call_id": getattr(msg, "tool_call_id", None),
},
)
# Title changes
title = chunk.get("title")
if title and title != last_title:
last_title = title
yield StreamEvent(type="title", data={"title": title})
yield StreamEvent(type="done", data={})
def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str:
"""Send a message and return the final text response.
Convenience wrapper around :meth:`stream` that returns only the
**last** ``message`` event's text. If the agent emits multiple
message segments in one turn, intermediate segments are discarded.
Use :meth:`stream` directly to capture all events.
Args:
message: User message text.
thread_id: Thread ID for conversation context. Auto-generated if None.
**kwargs: Override client defaults (same as stream()).
Returns:
The last AI message text, or empty string if no response.
"""
last_text = ""
for event in self.stream(message, thread_id=thread_id, **kwargs):
if event.type == "message":
last_text = event.data.get("content", "")
return last_text
# ------------------------------------------------------------------
# Public API — configuration queries
# ------------------------------------------------------------------
def list_models(self) -> list[dict]:
"""List available models from configuration.
Returns:
List of model config dicts.
"""
return [model.model_dump() for model in self._app_config.models]
def list_skills(self, enabled_only: bool = False) -> list[dict]:
"""List available skills.
Args:
enabled_only: If True, only return enabled skills.
Returns:
List of skill info dicts with name, description, category, enabled.
"""
from src.skills.loader import load_skills
return [
{
"name": s.name,
"description": s.description,
"category": s.category,
"enabled": s.enabled,
}
for s in load_skills(enabled_only=enabled_only)
]
def get_memory(self) -> dict:
"""Get current memory data.
Returns:
Memory data dict (see src/agents/memory/updater.py for structure).
"""
from src.agents.memory.updater import get_memory_data
return get_memory_data()
def get_model(self, name: str) -> dict | None:
"""Get a specific model's configuration by name.
Args:
name: Model name.
Returns:
Model config dict, or None if not found.
"""
model = self._app_config.get_model_config(name)
return model.model_dump() if model is not None else None
# ------------------------------------------------------------------
# Public API — MCP configuration
# ------------------------------------------------------------------
def get_mcp_config(self) -> dict[str, dict]:
"""Get MCP server configurations.
Returns:
Dict mapping server name to its config dict.
"""
config = get_extensions_config()
return {name: server.model_dump() for name, server in config.mcp_servers.items()}
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict[str, dict]:
"""Update MCP server configurations.
Writes to extensions_config.json and reloads the cache.
Args:
mcp_servers: Dict mapping server name to config dict.
Each value should contain keys like enabled, type, command, args, env, url, etc.
Returns:
The updated MCP config.
Raises:
OSError: If the config file cannot be written.
"""
config_path = ExtensionsConfig.resolve_config_path()
if config_path is None:
raise FileNotFoundError(
"Cannot locate extensions_config.json. "
"Pass config_path to DeerFlowClient or set DEER_FLOW_HOME."
)
current_config = get_extensions_config()
config_data = {
"mcpServers": mcp_servers,
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
}
self._atomic_write_json(config_path, config_data)
self._agent = None
reloaded = reload_extensions_config()
return {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}
# ------------------------------------------------------------------
# Public API — skills management
# ------------------------------------------------------------------
def get_skill(self, name: str) -> dict | None:
"""Get a specific skill by name.
Args:
name: Skill name.
Returns:
Skill info dict, or None if not found.
"""
from src.skills.loader import load_skills
skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
if skill is None:
return None
return {
"name": skill.name,
"description": skill.description,
"license": skill.license,
"category": skill.category,
"enabled": skill.enabled,
}
def update_skill(self, name: str, *, enabled: bool) -> dict:
"""Update a skill's enabled status.
Args:
name: Skill name.
enabled: New enabled status.
Returns:
Updated skill info dict.
Raises:
ValueError: If the skill is not found.
OSError: If the config file cannot be written.
"""
from src.skills.loader import load_skills
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == name), None)
if skill is None:
raise ValueError(f"Skill '{name}' not found")
config_path = ExtensionsConfig.resolve_config_path()
if config_path is None:
raise FileNotFoundError(
"Cannot locate extensions_config.json. "
"Pass config_path to DeerFlowClient or set DEER_FLOW_HOME."
)
extensions_config = get_extensions_config()
extensions_config.skills[name] = SkillStateConfig(enabled=enabled)
config_data = {
"mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()},
}
self._atomic_write_json(config_path, config_data)
self._agent = None
reload_extensions_config()
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
if updated is None:
raise RuntimeError(f"Skill '{name}' disappeared after update")
return {
"name": updated.name,
"description": updated.description,
"license": updated.license,
"category": updated.category,
"enabled": updated.enabled,
}
def install_skill(self, skill_path: str | Path) -> dict:
"""Install a skill from a .skill archive (ZIP).
Args:
skill_path: Path to the .skill file.
Returns:
Dict with success, skill_name, message.
Raises:
FileNotFoundError: If the file does not exist.
ValueError: If the file is invalid.
"""
from src.gateway.routers.skills import _validate_skill_frontmatter
from src.skills.loader import get_skills_root_path
path = Path(skill_path)
if not path.exists():
raise FileNotFoundError(f"Skill file not found: {skill_path}")
if not path.is_file():
raise ValueError(f"Path is not a file: {skill_path}")
if path.suffix != ".skill":
raise ValueError("File must have .skill extension")
if not zipfile.is_zipfile(path):
raise ValueError("File is not a valid ZIP archive")
skills_root = get_skills_root_path()
custom_dir = skills_root / "custom"
custom_dir.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
with zipfile.ZipFile(path, "r") as zf:
total_size = sum(info.file_size for info in zf.infolist())
if total_size > 100 * 1024 * 1024:
raise ValueError("Skill archive too large when extracted (>100MB)")
for info in zf.infolist():
if Path(info.filename).is_absolute() or ".." in Path(info.filename).parts:
raise ValueError(f"Unsafe path in archive: {info.filename}")
zf.extractall(tmp_path)
for p in tmp_path.rglob("*"):
if p.is_symlink():
p.unlink()
items = list(tmp_path.iterdir())
if not items:
raise ValueError("Skill archive is empty")
skill_dir = items[0] if len(items) == 1 and items[0].is_dir() else tmp_path
is_valid, message, skill_name = _validate_skill_frontmatter(skill_dir)
if not is_valid:
raise ValueError(f"Invalid skill: {message}")
if not re.fullmatch(r"[a-zA-Z0-9_-]+", skill_name):
raise ValueError(f"Invalid skill name: {skill_name}")
target = custom_dir / skill_name
if target.exists():
raise ValueError(f"Skill '{skill_name}' already exists")
shutil.copytree(skill_dir, target)
return {"success": True, "skill_name": skill_name, "message": f"Skill '{skill_name}' installed successfully"}
# ------------------------------------------------------------------
# Public API — memory management
# ------------------------------------------------------------------
def reload_memory(self) -> dict:
"""Reload memory data from file, forcing cache invalidation.
Returns:
The reloaded memory data dict.
"""
from src.agents.memory.updater import reload_memory_data
return reload_memory_data()
def get_memory_config(self) -> dict:
"""Get memory system configuration.
Returns:
Memory config dict.
"""
from src.config.memory_config import get_memory_config
config = get_memory_config()
return {
"enabled": config.enabled,
"storage_path": config.storage_path,
"debounce_seconds": config.debounce_seconds,
"max_facts": config.max_facts,
"fact_confidence_threshold": config.fact_confidence_threshold,
"injection_enabled": config.injection_enabled,
"max_injection_tokens": config.max_injection_tokens,
}
def get_memory_status(self) -> dict:
"""Get memory status: config + current data.
Returns:
Dict with "config" and "data" keys.
"""
return {
"config": self.get_memory_config(),
"data": self.get_memory(),
}
# ------------------------------------------------------------------
# Public API — file uploads
# ------------------------------------------------------------------
@staticmethod
def _get_uploads_dir(thread_id: str) -> Path:
"""Get (and create) the uploads directory for a thread."""
base = get_paths().sandbox_uploads_dir(thread_id)
base.mkdir(parents=True, exist_ok=True)
return base
def upload_files(self, thread_id: str, files: list[str | Path]) -> list[dict]:
"""Upload local files into a thread's uploads directory.
For PDF, PPT, Excel, and Word files, they are also converted to Markdown.
Args:
thread_id: Target thread ID.
files: List of local file paths to upload.
Returns:
List of file info dicts (filename, size, path, virtual_path).
Raises:
FileNotFoundError: If any file does not exist.
"""
from src.gateway.routers.uploads import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown
# Validate all files upfront to avoid partial uploads.
resolved_files = []
for f in files:
p = Path(f)
if not p.exists():
raise FileNotFoundError(f"File not found: {f}")
resolved_files.append(p)
uploads_dir = self._get_uploads_dir(thread_id)
results: list[dict] = []
for src_path in resolved_files:
dest = uploads_dir / src_path.name
shutil.copy2(src_path, dest)
info: dict[str, Any] = {
"filename": src_path.name,
"size": dest.stat().st_size,
"path": str(dest),
"virtual_path": f"/mnt/user-data/uploads/{src_path.name}",
}
if src_path.suffix.lower() in CONVERTIBLE_EXTENSIONS:
try:
try:
asyncio.get_running_loop()
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
md_path = pool.submit(lambda: asyncio.run(convert_file_to_markdown(dest))).result()
except RuntimeError:
md_path = asyncio.run(convert_file_to_markdown(dest))
except Exception:
logger.warning("Failed to convert %s to markdown", src_path.name, exc_info=True)
md_path = None
if md_path is not None:
info["markdown_file"] = md_path.name
info["markdown_virtual_path"] = f"/mnt/user-data/uploads/{md_path.name}"
results.append(info)
return results
def list_uploads(self, thread_id: str) -> list[dict]:
"""List files in a thread's uploads directory.
Args:
thread_id: Thread ID.
Returns:
List of file info dicts.
"""
uploads_dir = self._get_uploads_dir(thread_id)
if not uploads_dir.exists():
return []
files = []
for fp in sorted(uploads_dir.iterdir()):
if fp.is_file():
stat = fp.stat()
files.append({
"filename": fp.name,
"size": stat.st_size,
"path": str(fp),
"virtual_path": f"/mnt/user-data/uploads/{fp.name}",
"extension": fp.suffix,
"modified": stat.st_mtime,
})
return files
def delete_upload(self, thread_id: str, filename: str) -> None:
"""Delete a file from a thread's uploads directory.
Args:
thread_id: Thread ID.
filename: Filename to delete.
Raises:
FileNotFoundError: If the file does not exist.
PermissionError: If path traversal is detected.
"""
uploads_dir = self._get_uploads_dir(thread_id)
file_path = (uploads_dir / filename).resolve()
try:
file_path.relative_to(uploads_dir.resolve())
except ValueError:
raise PermissionError("Access denied: path traversal detected")
if not file_path.is_file():
raise FileNotFoundError(f"File not found: {filename}")
file_path.unlink()
# ------------------------------------------------------------------
# Public API — artifacts
# ------------------------------------------------------------------
    def get_artifact(self, thread_id: str, path: str) -> tuple[bytes, str]:
        """Read an artifact file produced by the agent.

        Args:
            thread_id: Thread ID.
            path: Virtual path (e.g. "mnt/user-data/outputs/file.txt").

        Returns:
            Tuple of (file_bytes, mime_type).

        Raises:
            FileNotFoundError: If the artifact does not exist.
            ValueError: If the path is invalid or is not a file.
            PermissionError: If path traversal is detected.
        """
        virtual_prefix = "mnt/user-data"
        clean_path = path.lstrip("/")
        # Match the prefix on a path-segment boundary so that a sibling such
        # as "mnt/user-data-other/x" is rejected instead of silently accepted.
        if clean_path != virtual_prefix and not clean_path.startswith(virtual_prefix + "/"):
            raise ValueError(f"Path must start with /{virtual_prefix}")
        relative = clean_path[len(virtual_prefix):].lstrip("/")
        base_dir = get_paths().sandbox_user_data_dir(thread_id)
        actual = (base_dir / relative).resolve()
        # Reject anything that resolves outside the thread's user-data directory.
        try:
            actual.relative_to(base_dir.resolve())
        except ValueError:
            raise PermissionError("Access denied: path traversal detected")
        if not actual.exists():
            raise FileNotFoundError(f"Artifact not found: {path}")
        if not actual.is_file():
            raise ValueError(f"Path is not a file: {path}")
        mime_type, _ = mimetypes.guess_type(actual)
        return actual.read_bytes(), mime_type or "application/octet-stream"

1292
backend/tests/test_client.py Normal file

File diff suppressed because it is too large


@@ -0,0 +1,312 @@
"""Live integration tests for DeerFlowClient with real API.

These tests require a working config.yaml with valid API credentials.
They are skipped in CI and must be run explicitly:

    PYTHONPATH=. uv run pytest tests/test_client_live.py -v -s
"""
import json
import os
import tempfile
from pathlib import Path
import pytest
# Skip entire module in CI or when no config.yaml exists
_skip_reason = None
if os.environ.get("CI"):
_skip_reason = "Live tests skipped in CI"
elif not Path(__file__).resolve().parents[2].joinpath("config.yaml").exists():
_skip_reason = "No config.yaml found — live tests require valid API credentials"
if _skip_reason:
pytest.skip(_skip_reason, allow_module_level=True)
from src.client import DeerFlowClient, StreamEvent
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def client():
"""Create a real DeerFlowClient (no mocks)."""
return DeerFlowClient(thinking_enabled=False)
@pytest.fixture
def thread_tmp(tmp_path):
"""Provide a unique thread_id + tmp directory for file operations."""
import uuid
tid = f"live-test-{uuid.uuid4().hex[:8]}"
return tid, tmp_path
# ===========================================================================
# Scenario 1: Basic chat — model responds coherently
# ===========================================================================
class TestLiveBasicChat:
def test_chat_returns_nonempty_string(self, client):
"""chat() returns a non-empty response from the real model."""
response = client.chat("Reply with exactly: HELLO")
assert isinstance(response, str)
assert len(response) > 0
print(f" chat response: {response}")
def test_chat_follows_instruction(self, client):
"""Model can follow a simple instruction."""
response = client.chat("What is 7 * 8? Reply with just the number.")
assert "56" in response
print(f" math response: {response}")
# ===========================================================================
# Scenario 2: Streaming — events arrive in correct order
# ===========================================================================
class TestLiveStreaming:
def test_stream_yields_message_and_done(self, client):
"""stream() produces at least one message event and ends with done."""
events = list(client.stream("Say hi in one word."))
types = [e.type for e in events]
assert "message" in types, f"Expected 'message' event, got: {types}"
assert types[-1] == "done"
for e in events:
assert isinstance(e, StreamEvent)
print(f" [{e.type}] {e.data}")
def test_stream_message_content_nonempty(self, client):
"""Streamed message events contain non-empty content."""
messages = [
e for e in client.stream("What color is the sky? One word.")
if e.type == "message"
]
assert len(messages) >= 1
for m in messages:
assert len(m.data.get("content", "")) > 0
# ===========================================================================
# Scenario 3: Tool use — agent calls a tool and returns result
# ===========================================================================
class TestLiveToolUse:
def test_agent_uses_bash_tool(self, client):
"""Agent uses bash tool when asked to run a command."""
events = list(client.stream(
"Use the bash tool to run: echo 'LIVE_TEST_OK'. "
"Then tell me the output."
))
types = [e.type for e in events]
print(f" event types: {types}")
for e in events:
print(f" [{e.type}] {e.data}")
# Should have tool_call + tool_result + message
assert "tool_call" in types, f"Expected tool_call, got: {types}"
assert "tool_result" in types, f"Expected tool_result, got: {types}"
assert "message" in types
tc = next(e for e in events if e.type == "tool_call")
assert tc.data["name"] == "bash"
tr = next(e for e in events if e.type == "tool_result")
assert "LIVE_TEST_OK" in tr.data["content"]
def test_agent_uses_ls_tool(self, client):
"""Agent uses ls tool to list a directory."""
events = list(client.stream(
"Use the ls tool to list the contents of /mnt/user-data/workspace. "
"Just report what you see."
))
types = [e.type for e in events]
print(f" event types: {types}")
assert "tool_call" in types
tc = next(e for e in events if e.type == "tool_call")
assert tc.data["name"] == "ls"
# ===========================================================================
# Scenario 4: Multi-tool chain — agent chains tools in sequence
# ===========================================================================
class TestLiveMultiToolChain:
    def test_write_then_read(self, client):
        """Agent writes a file, then reads it back."""
        events = list(client.stream(
            "Step 1: Use write_file to write 'integration_test_content' to "
            "/mnt/user-data/outputs/live_test.txt. "
            "Step 2: Use read_file to read that file back. "
            "Step 3: Tell me the content you read."
        ))
        types = [e.type for e in events]
        print(f" event types: {types}")
        for e in events:
            print(f" [{e.type}] {e.data}")
        tool_calls = [e for e in events if e.type == "tool_call"]
        tool_names = [tc.data["name"] for tc in tool_calls]
        assert "write_file" in tool_names, f"Expected write_file, got: {tool_names}"
        assert "read_file" in tool_names, f"Expected read_file, got: {tool_names}"
        # Final message should mention the content
        messages = [e for e in events if e.type == "message"]
        final_text = messages[-1].data["content"] if messages else ""
        assert "integration_test_content" in final_text.lower() or any(
            "integration_test_content" in e.data.get("content", "")
            for e in events if e.type == "tool_result"
        )
# ===========================================================================
# Scenario 5: File upload lifecycle with real filesystem
# ===========================================================================
class TestLiveFileUpload:
    def test_upload_list_delete(self, client, thread_tmp):
        """Upload → list → delete → verify deletion."""
        thread_id, tmp_path = thread_tmp
        # Create test files
        f1 = tmp_path / "test_upload_a.txt"
        f1.write_text("content A")
        f2 = tmp_path / "test_upload_b.txt"
        f2.write_text("content B")
        # Upload
        results = client.upload_files(thread_id, [f1, f2])
        assert len(results) == 2
        filenames = {r["filename"] for r in results}
        assert filenames == {"test_upload_a.txt", "test_upload_b.txt"}
        for r in results:
            assert r["size"] > 0
            assert r["virtual_path"].startswith("/mnt/user-data/uploads/")
        print(f" uploaded: {filenames}")
        # List
        listed = client.list_uploads(thread_id)
        assert len(listed) == 2
        print(f" listed: {[f['filename'] for f in listed]}")
        # Delete one
        client.delete_upload(thread_id, "test_upload_a.txt")
        remaining = client.list_uploads(thread_id)
        assert len(remaining) == 1
        assert remaining[0]["filename"] == "test_upload_b.txt"
        print(f" after delete: {[f['filename'] for f in remaining]}")
        # Delete the other
        client.delete_upload(thread_id, "test_upload_b.txt")
        assert client.list_uploads(thread_id) == []

    def test_upload_nonexistent_file_raises(self, client):
        with pytest.raises(FileNotFoundError):
            client.upload_files("t-fail", ["/nonexistent/path/file.txt"])
# ===========================================================================
# Scenario 6: Configuration query — real config loading
# ===========================================================================
class TestLiveConfigQueries:
    def test_list_models_returns_ark(self, client):
        """list_models() returns the configured ARK model."""
        models = client.list_models()
        assert len(models) >= 1
        names = [m["name"] for m in models]
        assert "ark-model" in names
        print(f" models: {names}")

    def test_get_model_found(self, client):
        """get_model() returns details for existing model."""
        model = client.get_model("ark-model")
        assert model is not None
        assert model["name"] == "ark-model"
        print(f" model detail: {model}")

    def test_get_model_not_found(self, client):
        assert client.get_model("nonexistent-model-xyz") is None

    def test_list_skills(self, client):
        """list_skills() runs without error."""
        skills = client.list_skills()
        assert isinstance(skills, list)
        print(f" skills count: {len(skills)}")
        for s in skills[:3]:
            print(f" - {s['name']}: {s['enabled']}")
# ===========================================================================
# Scenario 7: Artifact read after agent writes
# ===========================================================================
class TestLiveArtifact:
    def test_get_artifact_after_write(self, client):
        """Agent writes a file → client reads it back via get_artifact()."""
        import uuid
        thread_id = f"live-artifact-{uuid.uuid4().hex[:8]}"
        # Ask agent to write a file
        events = list(client.stream(
            "Use write_file to create /mnt/user-data/outputs/artifact_test.json "
            "with content: {\"status\": \"ok\", \"source\": \"live_test\"}",
            thread_id=thread_id,
        ))
        # Verify write happened
        tool_calls = [e for e in events if e.type == "tool_call"]
        assert any(tc.data["name"] == "write_file" for tc in tool_calls)
        # Read artifact
        content, mime = client.get_artifact(thread_id, "mnt/user-data/outputs/artifact_test.json")
        data = json.loads(content)
        assert data["status"] == "ok"
        assert data["source"] == "live_test"
        assert "json" in mime
        print(f" artifact: {data}, mime: {mime}")

    def test_get_artifact_not_found(self, client):
        with pytest.raises(FileNotFoundError):
            client.get_artifact("nonexistent-thread", "mnt/user-data/outputs/nope.txt")
# ===========================================================================
# Scenario 8: Per-call overrides
# ===========================================================================
class TestLiveOverrides:
    def test_thinking_disabled_still_works(self, client):
        """Explicit thinking_enabled=False override produces a response."""
        response = client.chat(
            "Say OK.", thinking_enabled=False,
        )
        assert len(response) > 0
        print(f" response: {response}")
# ===========================================================================
# Scenario 9: Error resilience
# ===========================================================================
class TestLiveErrorResilience:
    def test_delete_nonexistent_upload(self, client):
        with pytest.raises(FileNotFoundError):
            client.delete_upload("nonexistent-thread", "ghost.txt")

    def test_bad_artifact_path(self, client):
        with pytest.raises(ValueError):
            client.get_artifact("t", "invalid/path")

    def test_path_traversal_blocked(self, client):
        with pytest.raises(PermissionError):
            client.delete_upload("t", "../../etc/passwd")