diff --git a/.gitignore b/.gitignore index 91cfc7e..1ec225f 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ coverage/ .claude/ skills/custom/* logs/ +log/ # Local git hooks (keep only on this machine, do not push) .githooks/ diff --git a/README.md b/README.md index 4051db3..91a6cd3 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,31 @@ DeerFlow is model-agnostic — it works with any LLM that implements the OpenAI- - **Multimodal inputs** for image understanding and video comprehension - **Strong tool-use** for reliable function calling and structured outputs +## Embedded Python Client + +DeerFlow can be used as an embedded Python library without running the full HTTP services. The `DeerFlowClient` provides direct in-process access to all agent and Gateway capabilities: + +```python +from src.client import DeerFlowClient + +client = DeerFlowClient() + +# Chat +response = client.chat("Analyze this paper for me", thread_id="my-thread") + +# Streaming +for event in client.stream("hello"): + print(event.type, event.data) + +# Configuration & management +print(client.list_models()) +print(client.list_skills()) +client.update_skill("web-search", enabled=True) +client.upload_files("thread-1", ["./report.pdf"]) +``` + +See `backend/src/client.py` for full API documentation. + ## Documentation - [Contributing Guide](CONTRIBUTING.md) - Development environment setup and workflow diff --git a/backend/.gitignore b/backend/.gitignore index 231ce2b..6e56d9e 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -11,6 +11,9 @@ wheels/ agent_history.gif static/browser_history/*.gif +log/ +log/* + # Virtual environments .venv venv/ diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index ae93d0a..1e55289 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -47,7 +47,8 @@ deer-flow/ │ │ ├── config/ # Configuration system (app, model, sandbox, tool, etc.) 
│ │ ├── community/ # Community tools (tavily, jina_ai, firecrawl, image_search, aio_sandbox) │ │ ├── reflection/ # Dynamic module loading (resolve_variable, resolve_class) -│ │ └── utils/ # Utilities (network, readability) +│ │ ├── utils/ # Utilities (network, readability) +│ │ └── client.py # Embedded Python client (DeerFlowClient) │ ├── tests/ # Test suite │ └── docs/ # Documentation ├── frontend/ # Next.js frontend application @@ -289,7 +290,35 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → - `mcpServers` - Map of server name → config (enabled, type, command, args, env, url, headers, description) - `skills` - Map of skill name → state (enabled) -Both can be modified at runtime via Gateway API endpoints. +Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` methods. + +### Embedded Client (`src/client.py`) + +`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. + +**Architecture**: Imports the same `src/` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency. 
+ +**Agent Conversation** (replaces LangGraph Server): +- `chat(message, thread_id)` — synchronous, returns final text +- `stream(message, thread_id)` — yields `StreamEvent` (message, tool_call, tool_result, title, done) +- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent` +- Supports `checkpointer` parameter for state persistence across turns +- Invocation pattern: `agent.stream(state, config, context, stream_mode="values")` + +**Gateway Equivalent Methods** (replaces Gateway API): + +| Category | Methods | +|----------|---------| +| Models | `list_models()`, `get_model(name)` | +| MCP | `get_mcp_config()`, `update_mcp_config(servers)` | +| Skills | `list_skills()`, `get_skill(name)`, `update_skill(name, enabled)`, `install_skill(path)` | +| Memory | `get_memory()`, `reload_memory()`, `get_memory_config()`, `get_memory_status()` | +| Uploads | `upload_files(thread_id, files)`, `list_uploads(thread_id)`, `delete_upload(thread_id, filename)` | +| Artifacts | `get_artifact(thread_id, path)` → `(bytes, mime_type)` | + +**Key difference from Gateway**: Upload accepts local `Path` objects instead of HTTP `UploadFile`. Artifact returns `(bytes, mime_type)` instead of HTTP Response. + +**Tests**: `tests/test_client.py` (45 unit tests) ## Development Workflow diff --git a/backend/src/client.py b/backend/src/client.py new file mode 100644 index 0000000..a859c76 --- /dev/null +++ b/backend/src/client.py @@ -0,0 +1,786 @@ +"""DeerFlowClient — Embedded Python client for DeerFlow agent system. + +Provides direct programmatic access to DeerFlow's agent capabilities +without requiring LangGraph Server or Gateway API processes. 
+ +Usage: + from src.client import DeerFlowClient + + client = DeerFlowClient() + response = client.chat("Analyze this paper for me", thread_id="my-thread") + print(response) + + # Streaming + for event in client.stream("hello"): + print(event) +""" + +import asyncio +import json +import logging +import mimetypes +import re +import shutil +import tempfile +import uuid +import zipfile +from collections.abc import Generator +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from langchain.agents import create_agent +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.runnables import RunnableConfig + +from src.agents.lead_agent.agent import _build_middlewares +from src.agents.lead_agent.prompt import apply_prompt_template +from src.agents.thread_state import ThreadState +from src.config.app_config import get_app_config, reload_app_config +from src.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config +from src.models import create_chat_model +from src.config.paths import get_paths + +logger = logging.getLogger(__name__) + + +@dataclass +class StreamEvent: + """A single event from the streaming agent response. + + Attributes: + type: Event type — "message", "tool_call", "tool_result", "title", or "done". + data: Event payload. Contents vary by type. + """ + + type: str + data: dict[str, Any] = field(default_factory=dict) + + +class DeerFlowClient: + """Embedded Python client for DeerFlow agent system. + + Provides direct programmatic access to DeerFlow's agent capabilities + without requiring LangGraph Server or Gateway API processes. + + Note: + Multi-turn conversations require a ``checkpointer``. Without one, + each ``stream()`` / ``chat()`` call is stateless — ``thread_id`` + is only used for file isolation (uploads / artifacts). 
+ + The system prompt (including date, memory, and skills context) is + generated when the internal agent is first created and cached until + the configuration key changes. Call :meth:`reset_agent` to force + a refresh in long-running processes. + + Example:: + + from src.client import DeerFlowClient + + client = DeerFlowClient() + + # Simple one-shot + print(client.chat("hello")) + + # Streaming + for event in client.stream("hello"): + print(event.type, event.data) + + # Configuration queries + print(client.list_models()) + print(client.list_skills()) + """ + + def __init__( + self, + config_path: str | None = None, + checkpointer=None, + *, + model_name: str | None = None, + thinking_enabled: bool = True, + subagent_enabled: bool = False, + plan_mode: bool = False, + ): + """Initialize the client. + + Loads configuration but defers agent creation to first use. + + Args: + config_path: Path to config.yaml. Uses default resolution if None. + checkpointer: LangGraph checkpointer instance for state persistence. + Required for multi-turn conversations on the same thread_id. + Without a checkpointer, each call is stateless. + model_name: Override the default model name from config. + thinking_enabled: Enable model's extended thinking. + subagent_enabled: Enable subagent delegation. + plan_mode: Enable TodoList middleware for plan mode. + """ + if config_path is not None: + reload_app_config(config_path) + self._app_config = get_app_config() + + self._checkpointer = checkpointer + self._model_name = model_name + self._thinking_enabled = thinking_enabled + self._subagent_enabled = subagent_enabled + self._plan_mode = plan_mode + + # Lazy agent — created on first call, recreated when config changes. + self._agent = None + self._agent_config_key: tuple | None = None + + def reset_agent(self) -> None: + """Force the internal agent to be recreated on the next call. + + Use this after external changes (e.g. 
memory updates, skill + installations) that should be reflected in the system prompt + or tool set. + """ + self._agent = None + self._agent_config_key = None + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _atomic_write_json(path: Path, data: dict) -> None: + """Write JSON to *path* atomically (temp file + replace).""" + fd = tempfile.NamedTemporaryFile( + mode="w", dir=path.parent, suffix=".tmp", delete=False, + ) + try: + json.dump(data, fd, indent=2) + fd.close() + Path(fd.name).replace(path) + except BaseException: + fd.close() + Path(fd.name).unlink(missing_ok=True) + raise + + def _get_runnable_config(self, thread_id: str, **overrides) -> RunnableConfig: + """Build a RunnableConfig for agent invocation.""" + configurable = { + "thread_id": thread_id, + "model_name": overrides.get("model_name", self._model_name), + "thinking_enabled": overrides.get("thinking_enabled", self._thinking_enabled), + "is_plan_mode": overrides.get("plan_mode", self._plan_mode), + "subagent_enabled": overrides.get("subagent_enabled", self._subagent_enabled), + } + return RunnableConfig( + configurable=configurable, + recursion_limit=overrides.get("recursion_limit", 100), + ) + + def _ensure_agent(self, config: RunnableConfig): + """Create (or recreate) the agent when config-dependent params change.""" + cfg = config.get("configurable", {}) + key = ( + cfg.get("model_name"), + cfg.get("thinking_enabled"), + cfg.get("is_plan_mode"), + cfg.get("subagent_enabled"), + ) + + if self._agent is not None and self._agent_config_key == key: + return + + thinking_enabled = cfg.get("thinking_enabled", True) + model_name = cfg.get("model_name") + subagent_enabled = cfg.get("subagent_enabled", False) + max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) + + kwargs: dict[str, Any] = { + "model": create_chat_model(name=model_name, 
thinking_enabled=thinking_enabled), + "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled), + "middleware": _build_middlewares(config, model_name=model_name), + "system_prompt": apply_prompt_template( + subagent_enabled=subagent_enabled, + max_concurrent_subagents=max_concurrent_subagents, + ), + "state_schema": ThreadState, + } + if self._checkpointer is not None: + kwargs["checkpointer"] = self._checkpointer + + self._agent = create_agent(**kwargs) + self._agent_config_key = key + logger.info("Agent created: model=%s, thinking=%s", model_name, thinking_enabled) + + @staticmethod + def _get_tools(*, model_name: str | None, subagent_enabled: bool): + """Lazy import to avoid circular dependency at module level.""" + from src.tools import get_available_tools + + return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + + @staticmethod + def _extract_text(content) -> str: + """Extract plain text from AIMessage content (str or list of blocks).""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict) and block.get("type") == "text": + parts.append(block["text"]) + return "\n".join(parts) if parts else "" + return str(content) + + # ------------------------------------------------------------------ + # Public API — conversation + # ------------------------------------------------------------------ + + def stream( + self, + message: str, + *, + thread_id: str | None = None, + **kwargs, + ) -> Generator[StreamEvent, None, None]: + """Stream a conversation turn, yielding events incrementally. + + Each call sends one user message and yields events until the agent + finishes its turn. A ``checkpointer`` must be provided at init time + for multi-turn context to be preserved across calls. + + Args: + message: User message text. + thread_id: Thread ID for conversation context. 
Auto-generated if None. + **kwargs: Override client defaults (model_name, thinking_enabled, + plan_mode, subagent_enabled, recursion_limit). + + Yields: + StreamEvent with one of: + - type="message" data={"content": str} + - type="tool_call" data={"name": str, "args": dict, "id": str} + - type="tool_result" data={"name": str, "content": str, "tool_call_id": str} + - type="title" data={"title": str} + - type="done" data={} + """ + if thread_id is None: + thread_id = str(uuid.uuid4()) + + config = self._get_runnable_config(thread_id, **kwargs) + self._ensure_agent(config) + + state: dict[str, Any] = {"messages": [HumanMessage(content=message)]} + context = {"thread_id": thread_id} + + seen_ids: set[str] = set() + last_title: str | None = None + + for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"): + messages = chunk.get("messages", []) + + for msg in messages: + msg_id = getattr(msg, "id", None) + if msg_id and msg_id in seen_ids: + continue + if msg_id: + seen_ids.add(msg_id) + + if isinstance(msg, AIMessage): + if msg.tool_calls: + for tc in msg.tool_calls: + yield StreamEvent( + type="tool_call", + data={"name": tc["name"], "args": tc["args"], "id": tc.get("id")}, + ) + + text = self._extract_text(msg.content) + if text: + yield StreamEvent(type="message", data={"content": text}) + + elif isinstance(msg, ToolMessage): + yield StreamEvent( + type="tool_result", + data={ + "name": getattr(msg, "name", None), + "content": msg.content if isinstance(msg.content, str) else str(msg.content), + "tool_call_id": getattr(msg, "tool_call_id", None), + }, + ) + + # Title changes + title = chunk.get("title") + if title and title != last_title: + last_title = title + yield StreamEvent(type="title", data={"title": title}) + + yield StreamEvent(type="done", data={}) + + def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str: + """Send a message and return the final text response. 
+ + Convenience wrapper around :meth:`stream` that returns only the + **last** ``message`` event's text. If the agent emits multiple + message segments in one turn, intermediate segments are discarded. + Use :meth:`stream` directly to capture all events. + + Args: + message: User message text. + thread_id: Thread ID for conversation context. Auto-generated if None. + **kwargs: Override client defaults (same as stream()). + + Returns: + The last AI message text, or empty string if no response. + """ + last_text = "" + for event in self.stream(message, thread_id=thread_id, **kwargs): + if event.type == "message": + last_text = event.data.get("content", "") + return last_text + + # ------------------------------------------------------------------ + # Public API — configuration queries + # ------------------------------------------------------------------ + + def list_models(self) -> list[dict]: + """List available models from configuration. + + Returns: + List of model config dicts. + """ + return [model.model_dump() for model in self._app_config.models] + + def list_skills(self, enabled_only: bool = False) -> list[dict]: + """List available skills. + + Args: + enabled_only: If True, only return enabled skills. + + Returns: + List of skill info dicts with name, description, category, enabled. + """ + from src.skills.loader import load_skills + + return [ + { + "name": s.name, + "description": s.description, + "category": s.category, + "enabled": s.enabled, + } + for s in load_skills(enabled_only=enabled_only) + ] + + def get_memory(self) -> dict: + """Get current memory data. + + Returns: + Memory data dict (see src/agents/memory/updater.py for structure). + """ + from src.agents.memory.updater import get_memory_data + + return get_memory_data() + + def get_model(self, name: str) -> dict | None: + """Get a specific model's configuration by name. + + Args: + name: Model name. + + Returns: + Model config dict, or None if not found. 
+ """ + model = self._app_config.get_model_config(name) + return model.model_dump() if model is not None else None + + # ------------------------------------------------------------------ + # Public API — MCP configuration + # ------------------------------------------------------------------ + + def get_mcp_config(self) -> dict[str, dict]: + """Get MCP server configurations. + + Returns: + Dict mapping server name to its config dict. + """ + config = get_extensions_config() + return {name: server.model_dump() for name, server in config.mcp_servers.items()} + + def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict[str, dict]: + """Update MCP server configurations. + + Writes to extensions_config.json and reloads the cache. + + Args: + mcp_servers: Dict mapping server name to config dict. + Each value should contain keys like enabled, type, command, args, env, url, etc. + + Returns: + The updated MCP config. + + Raises: + OSError: If the config file cannot be written. + """ + config_path = ExtensionsConfig.resolve_config_path() + if config_path is None: + raise FileNotFoundError( + "Cannot locate extensions_config.json. " + "Pass config_path to DeerFlowClient or set DEER_FLOW_HOME." + ) + + current_config = get_extensions_config() + + config_data = { + "mcpServers": mcp_servers, + "skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()}, + } + + self._atomic_write_json(config_path, config_data) + + self._agent = None + reloaded = reload_extensions_config() + return {name: server.model_dump() for name, server in reloaded.mcp_servers.items()} + + # ------------------------------------------------------------------ + # Public API — skills management + # ------------------------------------------------------------------ + + def get_skill(self, name: str) -> dict | None: + """Get a specific skill by name. + + Args: + name: Skill name. + + Returns: + Skill info dict, or None if not found. 
+ """ + from src.skills.loader import load_skills + + skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None) + if skill is None: + return None + return { + "name": skill.name, + "description": skill.description, + "license": skill.license, + "category": skill.category, + "enabled": skill.enabled, + } + + def update_skill(self, name: str, *, enabled: bool) -> dict: + """Update a skill's enabled status. + + Args: + name: Skill name. + enabled: New enabled status. + + Returns: + Updated skill info dict. + + Raises: + ValueError: If the skill is not found. + OSError: If the config file cannot be written. + """ + from src.skills.loader import load_skills + + skills = load_skills(enabled_only=False) + skill = next((s for s in skills if s.name == name), None) + if skill is None: + raise ValueError(f"Skill '{name}' not found") + + config_path = ExtensionsConfig.resolve_config_path() + if config_path is None: + raise FileNotFoundError( + "Cannot locate extensions_config.json. " + "Pass config_path to DeerFlowClient or set DEER_FLOW_HOME." + ) + + extensions_config = get_extensions_config() + extensions_config.skills[name] = SkillStateConfig(enabled=enabled) + + config_data = { + "mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()}, + "skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()}, + } + + self._atomic_write_json(config_path, config_data) + + self._agent = None + reload_extensions_config() + + updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None) + if updated is None: + raise RuntimeError(f"Skill '{name}' disappeared after update") + return { + "name": updated.name, + "description": updated.description, + "license": updated.license, + "category": updated.category, + "enabled": updated.enabled, + } + + def install_skill(self, skill_path: str | Path) -> dict: + """Install a skill from a .skill archive (ZIP). 
+ + Args: + skill_path: Path to the .skill file. + + Returns: + Dict with success, skill_name, message. + + Raises: + FileNotFoundError: If the file does not exist. + ValueError: If the file is invalid. + """ + from src.gateway.routers.skills import _validate_skill_frontmatter + from src.skills.loader import get_skills_root_path + + path = Path(skill_path) + if not path.exists(): + raise FileNotFoundError(f"Skill file not found: {skill_path}") + if not path.is_file(): + raise ValueError(f"Path is not a file: {skill_path}") + if path.suffix != ".skill": + raise ValueError("File must have .skill extension") + if not zipfile.is_zipfile(path): + raise ValueError("File is not a valid ZIP archive") + + skills_root = get_skills_root_path() + custom_dir = skills_root / "custom" + custom_dir.mkdir(parents=True, exist_ok=True) + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + with zipfile.ZipFile(path, "r") as zf: + total_size = sum(info.file_size for info in zf.infolist()) + if total_size > 100 * 1024 * 1024: + raise ValueError("Skill archive too large when extracted (>100MB)") + for info in zf.infolist(): + if Path(info.filename).is_absolute() or ".." 
in Path(info.filename).parts: + raise ValueError(f"Unsafe path in archive: {info.filename}") + zf.extractall(tmp_path) + for p in tmp_path.rglob("*"): + if p.is_symlink(): + p.unlink() + + items = list(tmp_path.iterdir()) + if not items: + raise ValueError("Skill archive is empty") + + skill_dir = items[0] if len(items) == 1 and items[0].is_dir() else tmp_path + + is_valid, message, skill_name = _validate_skill_frontmatter(skill_dir) + if not is_valid: + raise ValueError(f"Invalid skill: {message}") + if not re.fullmatch(r"[a-zA-Z0-9_-]+", skill_name): + raise ValueError(f"Invalid skill name: {skill_name}") + + target = custom_dir / skill_name + if target.exists(): + raise ValueError(f"Skill '{skill_name}' already exists") + + shutil.copytree(skill_dir, target) + + return {"success": True, "skill_name": skill_name, "message": f"Skill '{skill_name}' installed successfully"} + + # ------------------------------------------------------------------ + # Public API — memory management + # ------------------------------------------------------------------ + + def reload_memory(self) -> dict: + """Reload memory data from file, forcing cache invalidation. + + Returns: + The reloaded memory data dict. + """ + from src.agents.memory.updater import reload_memory_data + + return reload_memory_data() + + def get_memory_config(self) -> dict: + """Get memory system configuration. + + Returns: + Memory config dict. + """ + from src.config.memory_config import get_memory_config + + config = get_memory_config() + return { + "enabled": config.enabled, + "storage_path": config.storage_path, + "debounce_seconds": config.debounce_seconds, + "max_facts": config.max_facts, + "fact_confidence_threshold": config.fact_confidence_threshold, + "injection_enabled": config.injection_enabled, + "max_injection_tokens": config.max_injection_tokens, + } + + def get_memory_status(self) -> dict: + """Get memory status: config + current data. + + Returns: + Dict with "config" and "data" keys. 
+ """ + return { + "config": self.get_memory_config(), + "data": self.get_memory(), + } + + # ------------------------------------------------------------------ + # Public API — file uploads + # ------------------------------------------------------------------ + + @staticmethod + def _get_uploads_dir(thread_id: str) -> Path: + """Get (and create) the uploads directory for a thread.""" + base = get_paths().sandbox_uploads_dir(thread_id) + base.mkdir(parents=True, exist_ok=True) + return base + + def upload_files(self, thread_id: str, files: list[str | Path]) -> list[dict]: + """Upload local files into a thread's uploads directory. + + For PDF, PPT, Excel, and Word files, they are also converted to Markdown. + + Args: + thread_id: Target thread ID. + files: List of local file paths to upload. + + Returns: + List of file info dicts (filename, size, path, virtual_path). + + Raises: + FileNotFoundError: If any file does not exist. + """ + from src.gateway.routers.uploads import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown + + # Validate all files upfront to avoid partial uploads. 
+ resolved_files = [] + for f in files: + p = Path(f) + if not p.exists(): + raise FileNotFoundError(f"File not found: {f}") + resolved_files.append(p) + + uploads_dir = self._get_uploads_dir(thread_id) + results: list[dict] = [] + + for src_path in resolved_files: + + dest = uploads_dir / src_path.name + shutil.copy2(src_path, dest) + + info: dict[str, Any] = { + "filename": src_path.name, + "size": dest.stat().st_size, + "path": str(dest), + "virtual_path": f"/mnt/user-data/uploads/{src_path.name}", + } + + if src_path.suffix.lower() in CONVERTIBLE_EXTENSIONS: + try: + try: + asyncio.get_running_loop() + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as pool: + md_path = pool.submit(lambda: asyncio.run(convert_file_to_markdown(dest))).result() + except RuntimeError: + md_path = asyncio.run(convert_file_to_markdown(dest)) + except Exception: + logger.warning("Failed to convert %s to markdown", src_path.name, exc_info=True) + md_path = None + + if md_path is not None: + info["markdown_file"] = md_path.name + info["markdown_virtual_path"] = f"/mnt/user-data/uploads/{md_path.name}" + + results.append(info) + + return results + + def list_uploads(self, thread_id: str) -> list[dict]: + """List files in a thread's uploads directory. + + Args: + thread_id: Thread ID. + + Returns: + List of file info dicts. + """ + uploads_dir = self._get_uploads_dir(thread_id) + if not uploads_dir.exists(): + return [] + + files = [] + for fp in sorted(uploads_dir.iterdir()): + if fp.is_file(): + stat = fp.stat() + files.append({ + "filename": fp.name, + "size": stat.st_size, + "path": str(fp), + "virtual_path": f"/mnt/user-data/uploads/{fp.name}", + "extension": fp.suffix, + "modified": stat.st_mtime, + }) + return files + + def delete_upload(self, thread_id: str, filename: str) -> None: + """Delete a file from a thread's uploads directory. + + Args: + thread_id: Thread ID. + filename: Filename to delete. 
+ + Raises: + FileNotFoundError: If the file does not exist. + PermissionError: If path traversal is detected. + """ + uploads_dir = self._get_uploads_dir(thread_id) + file_path = (uploads_dir / filename).resolve() + + try: + file_path.relative_to(uploads_dir.resolve()) + except ValueError: + raise PermissionError("Access denied: path traversal detected") + + if not file_path.is_file(): + raise FileNotFoundError(f"File not found: {filename}") + + file_path.unlink() + + # ------------------------------------------------------------------ + # Public API — artifacts + # ------------------------------------------------------------------ + + def get_artifact(self, thread_id: str, path: str) -> tuple[bytes, str]: + """Read an artifact file produced by the agent. + + Args: + thread_id: Thread ID. + path: Virtual path (e.g. "mnt/user-data/outputs/file.txt"). + + Returns: + Tuple of (file_bytes, mime_type). + + Raises: + FileNotFoundError: If the artifact does not exist. + ValueError: If the path is invalid. 
+ """ + virtual_prefix = "mnt/user-data" + clean_path = path.lstrip("/") + if not clean_path.startswith(virtual_prefix): + raise ValueError(f"Path must start with /{virtual_prefix}") + + relative = clean_path[len(virtual_prefix):].lstrip("/") + base_dir = get_paths().sandbox_user_data_dir(thread_id) + actual = (base_dir / relative).resolve() + + try: + actual.relative_to(base_dir.resolve()) + except ValueError: + raise PermissionError("Access denied: path traversal detected") + if not actual.exists(): + raise FileNotFoundError(f"Artifact not found: {path}") + if not actual.is_file(): + raise ValueError(f"Path is not a file: {path}") + + mime_type, _ = mimetypes.guess_type(actual) + return actual.read_bytes(), mime_type or "application/octet-stream" diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py new file mode 100644 index 0000000..729a1fe --- /dev/null +++ b/backend/tests/test_client.py @@ -0,0 +1,1292 @@ +"""Tests for DeerFlowClient.""" + +import json +import tempfile +import zipfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage # noqa: F401 + +from src.client import DeerFlowClient + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_app_config(): + """Provide a minimal AppConfig mock.""" + model = MagicMock() + model.name = "test-model" + model.model_dump.return_value = {"name": "test-model", "use": "langchain_openai:ChatOpenAI"} + + config = MagicMock() + config.models = [model] + return config + + +@pytest.fixture +def client(mock_app_config): + """Create a DeerFlowClient with mocked config loading.""" + with patch("src.client.get_app_config", return_value=mock_app_config): + return DeerFlowClient() + + +# 
# ---------------------------------------------------------------------------
# __init__
# ---------------------------------------------------------------------------

class TestClientInit:
    """Constructor behavior: defaults, overrides, config path, checkpointer."""

    def test_default_params(self, client):
        # Defaults: no explicit model, thinking on, subagents/plan mode off,
        # and the agent is built lazily (still None right after __init__).
        assert client._model_name is None
        assert client._thinking_enabled is True
        assert client._subagent_enabled is False
        assert client._plan_mode is False
        assert client._checkpointer is None
        assert client._agent is None

    def test_custom_params(self, mock_app_config):
        # All four behavioral flags are stored exactly as passed.
        with patch("src.client.get_app_config", return_value=mock_app_config):
            c = DeerFlowClient(
                model_name="gpt-4",
                thinking_enabled=False,
                subagent_enabled=True,
                plan_mode=True,
            )
            assert c._model_name == "gpt-4"
            assert c._thinking_enabled is False
            assert c._subagent_enabled is True
            assert c._plan_mode is True

    def test_custom_config_path(self, mock_app_config):
        # A custom config_path must trigger a reload from that exact path.
        with (
            patch("src.client.reload_app_config") as mock_reload,
            patch("src.client.get_app_config", return_value=mock_app_config),
        ):
            DeerFlowClient(config_path="/tmp/custom.yaml")
            mock_reload.assert_called_once_with("/tmp/custom.yaml")

    def test_checkpointer_stored(self, mock_app_config):
        # The checkpointer object is stored by identity, not copied.
        cp = MagicMock()
        with patch("src.client.get_app_config", return_value=mock_app_config):
            c = DeerFlowClient(checkpointer=cp)
            assert c._checkpointer is cp


# ---------------------------------------------------------------------------
# list_models / list_skills / get_memory
# ---------------------------------------------------------------------------

class TestConfigQueries:
    """Read-only configuration queries: models, skills, memory snapshot."""

    def test_list_models(self, client):
        models = client.list_models()
        assert len(models) == 1
        assert models[0]["name"] == "test-model"

    def test_list_skills(self, client):
        # list_skills() converts Skill objects into plain dicts and, by
        # default, asks the loader for ALL skills (enabled_only=False).
        skill = MagicMock()
        skill.name = "web-search"
        skill.description = "Search the web"
        skill.category = "public"
        skill.enabled = True

        with patch("src.skills.loader.load_skills", return_value=[skill]) as mock_load:
            result = client.list_skills()
            mock_load.assert_called_once_with(enabled_only=False)

        assert len(result) == 1
        assert result[0] == {
            "name": "web-search",
            "description": "Search the web",
            "category": "public",
            "enabled": True,
        }

    def test_list_skills_enabled_only(self, client):
        # The enabled_only flag is forwarded to the loader untouched.
        with patch("src.skills.loader.load_skills", return_value=[]) as mock_load:
            client.list_skills(enabled_only=True)
            mock_load.assert_called_once_with(enabled_only=True)

    def test_get_memory(self, client):
        # get_memory() is a thin pass-through over the updater's data.
        memory = {"version": "1.0", "facts": []}
        with patch("src.agents.memory.updater.get_memory_data", return_value=memory) as mock_mem:
            result = client.get_memory()
            mock_mem.assert_called_once()
        assert result == memory


# ---------------------------------------------------------------------------
# stream / chat
# ---------------------------------------------------------------------------

def _make_agent_mock(chunks: list[dict]):
    """Create a mock agent whose .stream() yields the given chunks."""
    agent = MagicMock()
    agent.stream.return_value = iter(chunks)
    return agent


class TestStream:
    """Event emission from stream(): message/tool/title/done ordering."""

    def test_basic_message(self, client):
        """stream() emits message + done for a simple AI reply."""
        ai = AIMessage(content="Hello!", id="ai-1")
        chunks = [
            {"messages": [HumanMessage(content="hi", id="h-1")]},
            {"messages": [HumanMessage(content="hi", id="h-1"), ai]},
        ]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("hi", thread_id="t1"))

        types = [e.type for e in events]
        assert "message" in types
        assert types[-1] == "done"
        msg_events = [e for e in events if e.type == "message"]
        assert msg_events[0].data["content"] == "Hello!"

    def test_tool_call_and_result(self, client):
        """stream() emits tool_call and tool_result events."""
        ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}])
        tool = ToolMessage(content="file.txt", id="tm-1", tool_call_id="tc-1", name="bash")
        ai2 = AIMessage(content="Here are the files.", id="ai-2")

        # Each chunk is a cumulative state snapshot, as LangGraph emits them.
        chunks = [
            {"messages": [HumanMessage(content="list files", id="h-1"), ai]},
            {"messages": [HumanMessage(content="list files", id="h-1"), ai, tool]},
            {"messages": [HumanMessage(content="list files", id="h-1"), ai, tool, ai2]},
        ]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("list files", thread_id="t2"))

        types = [e.type for e in events]
        assert "tool_call" in types
        assert "tool_result" in types
        assert "message" in types
        assert types[-1] == "done"

    def test_title_event(self, client):
        """stream() emits title event when title appears in state."""
        ai = AIMessage(content="ok", id="ai-1")
        chunks = [
            {"messages": [HumanMessage(content="hi", id="h-1"), ai], "title": "Greeting"},
        ]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("hi", thread_id="t3"))

        title_events = [e for e in events if e.type == "title"]
        assert len(title_events) == 1
        assert title_events[0].data["title"] == "Greeting"

    def test_deduplication(self, client):
        """Messages with the same id are not emitted twice."""
        ai = AIMessage(content="Hello!", id="ai-1")
        chunks = [
            {"messages": [HumanMessage(content="hi", id="h-1"), ai]},
            {"messages": [HumanMessage(content="hi", id="h-1"), ai]},  # duplicate
        ]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("hi", thread_id="t4"))

        msg_events = [e for e in events if e.type == "message"]
        assert len(msg_events) == 1

    def test_auto_thread_id(self, client):
        """stream() auto-generates a thread_id if not provided."""
        agent = _make_agent_mock([{"messages": [AIMessage(content="ok", id="ai-1")]}])

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("hi"))

        # Should not raise; done event proves it completed
        assert events[-1].type == "done"

    def test_list_content_blocks(self, client):
        """stream() handles AIMessage with list-of-blocks content."""
        # Thinking blocks are dropped; only the text block surfaces.
        ai = AIMessage(
            content=[
                {"type": "thinking", "thinking": "hmm"},
                {"type": "text", "text": "result"},
            ],
            id="ai-1",
        )
        chunks = [{"messages": [ai]}]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("hi", thread_id="t5"))

        msg_events = [e for e in events if e.type == "message"]
        assert len(msg_events) == 1
        assert msg_events[0].data["content"] == "result"


class TestChat:
    """chat() — synchronous convenience wrapper over stream()."""

    def test_returns_last_message(self, client):
        """chat() returns the last AI message text."""
        ai1 = AIMessage(content="thinking...", id="ai-1")
        ai2 = AIMessage(content="final answer", id="ai-2")
        chunks = [
            {"messages": [HumanMessage(content="q", id="h-1"), ai1]},
            {"messages": [HumanMessage(content="q", id="h-1"), ai1, ai2]},
        ]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            result = client.chat("q", thread_id="t6")

        assert result == "final answer"

    def test_empty_response(self, client):
        """chat() returns empty string if no AI message produced."""
        chunks = [{"messages": []}]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            result = client.chat("q", thread_id="t7")

        assert result == ""


# ---------------------------------------------------------------------------
# _extract_text
# ---------------------------------------------------------------------------

class TestExtractText:
    """Static text extraction from the various message-content shapes."""

    def test_string(self):
        assert DeerFlowClient._extract_text("hello") == "hello"

    def test_list_text_blocks(self):
        # Non-text blocks (e.g. thinking) are skipped; text blocks join on \n.
        content = [
            {"type": "text", "text": "first"},
            {"type": "thinking", "thinking": "skip"},
            {"type": "text", "text": "second"},
        ]
        assert DeerFlowClient._extract_text(content) == "first\nsecond"

    def test_list_plain_strings(self):
        assert DeerFlowClient._extract_text(["a", "b"]) == "a\nb"

    def test_empty_list(self):
        assert DeerFlowClient._extract_text([]) == ""

    def test_other_type(self):
        # Anything else falls back to str().
        assert DeerFlowClient._extract_text(42) == "42"


# ---------------------------------------------------------------------------
# _ensure_agent
# ---------------------------------------------------------------------------

class TestEnsureAgent:
    """Lazy agent construction and config-keyed caching."""

    def test_creates_agent(self, client):
        """_ensure_agent creates an agent on first call."""
        mock_agent = MagicMock()
        config = client._get_runnable_config("t1")

        with (
            patch("src.client.create_chat_model"),
            patch("src.client.create_agent", return_value=mock_agent),
            patch("src.client._build_middlewares", return_value=[]),
            patch("src.client.apply_prompt_template", return_value="prompt"),
            patch.object(client, "_get_tools", return_value=[]),
        ):
            client._ensure_agent(config)

        assert client._agent is mock_agent

    def test_reuses_agent_same_config(self, client):
        """_ensure_agent does not recreate if config key unchanged."""
        mock_agent = MagicMock()
        client._agent = mock_agent
        # Cache key tuple mirrors (model_name, thinking, subagent, plan_mode)
        # defaults — assumes key layout matches src/client.py; verify there.
        client._agent_config_key = (None, True, False, False)

        config = client._get_runnable_config("t1")
        client._ensure_agent(config)

        # Should still be the same mock — no recreation
        assert client._agent is mock_agent
# ---------------------------------------------------------------------------
# get_model
# ---------------------------------------------------------------------------

class TestGetModel:
    """Single-model lookup: found vs. missing."""

    def test_found(self, client):
        model_cfg = MagicMock()
        model_cfg.model_dump.return_value = {"name": "test-model"}
        client._app_config.get_model_config.return_value = model_cfg

        result = client.get_model("test-model")
        assert result == {"name": "test-model"}

    def test_not_found(self, client):
        # Missing models return None rather than raising.
        client._app_config.get_model_config.return_value = None
        assert client.get_model("nonexistent") is None


# ---------------------------------------------------------------------------
# MCP config
# ---------------------------------------------------------------------------

class TestMcpConfig:
    """MCP server configuration: read and read-modify-write round trip."""

    def test_get_mcp_config(self, client):
        server = MagicMock()
        server.model_dump.return_value = {"enabled": True, "type": "stdio"}
        ext_config = MagicMock()
        ext_config.mcp_servers = {"github": server}

        with patch("src.client.get_extensions_config", return_value=ext_config):
            result = client.get_mcp_config()

        assert "github" in result
        assert result["github"]["enabled"] is True

    def test_update_mcp_config(self, client):
        # Set up current config with skills
        current_config = MagicMock()
        current_config.skills = {}

        reloaded_server = MagicMock()
        reloaded_server.model_dump.return_value = {"enabled": True, "type": "sse"}
        reloaded_config = MagicMock()
        reloaded_config.mcp_servers = {"new-server": reloaded_server}

        # delete=False so the file survives the with-block; cleaned in finally.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump({}, f)
            tmp_path = Path(f.name)

        try:
            # Pre-set agent to verify it gets invalidated
            client._agent = MagicMock()

            with (
                patch("src.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
                patch("src.client.get_extensions_config", return_value=current_config),
                patch("src.client.reload_extensions_config", return_value=reloaded_config),
            ):
                result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}})

            assert "new-server" in result
            assert client._agent is None  # M2: agent invalidated

            # Verify file was actually written
            with open(tmp_path) as f:
                saved = json.load(f)
            assert "mcpServers" in saved
        finally:
            tmp_path.unlink()


# ---------------------------------------------------------------------------
# Skills management
# ---------------------------------------------------------------------------

class TestSkillsManagement:
    """Skill lookup, enable/disable toggling, and .skill archive install."""

    def _make_skill(self, name="test-skill", enabled=True):
        # Minimal Skill stand-in with the attributes the client reads.
        s = MagicMock()
        s.name = name
        s.description = "A test skill"
        s.license = "MIT"
        s.category = "public"
        s.enabled = enabled
        return s

    def test_get_skill_found(self, client):
        skill = self._make_skill()
        with patch("src.skills.loader.load_skills", return_value=[skill]):
            result = client.get_skill("test-skill")
        assert result is not None
        assert result["name"] == "test-skill"

    def test_get_skill_not_found(self, client):
        with patch("src.skills.loader.load_skills", return_value=[]):
            result = client.get_skill("nonexistent")
        assert result is None

    def test_update_skill(self, client):
        # load_skills is consulted twice: before (lookup) and after (result).
        skill = self._make_skill(enabled=True)
        updated_skill = self._make_skill(enabled=False)

        ext_config = MagicMock()
        ext_config.mcp_servers = {}
        ext_config.skills = {}

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump({}, f)
            tmp_path = Path(f.name)

        try:
            # Pre-set agent to verify it gets invalidated
            client._agent = MagicMock()

            with (
                patch("src.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]),
                patch("src.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
                patch("src.client.get_extensions_config", return_value=ext_config),
                patch("src.client.reload_extensions_config"),
            ):
                result = client.update_skill("test-skill", enabled=False)

            assert result["enabled"] is False
            assert client._agent is None  # M2: agent invalidated
        finally:
            tmp_path.unlink()

    def test_update_skill_not_found(self, client):
        with patch("src.skills.loader.load_skills", return_value=[]):
            with pytest.raises(ValueError, match="not found"):
                client.update_skill("nonexistent", enabled=True)

    def test_install_skill(self, client):
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)

            # Create a valid .skill archive
            skill_dir = tmp_path / "my-skill"
            skill_dir.mkdir()
            (skill_dir / "SKILL.md").write_text("---\nname: my-skill\ndescription: A skill\n---\nContent")

            archive_path = tmp_path / "my-skill.skill"
            with zipfile.ZipFile(archive_path, "w") as zf:
                zf.write(skill_dir / "SKILL.md", "my-skill/SKILL.md")

            skills_root = tmp_path / "skills"
            (skills_root / "custom").mkdir(parents=True)

            with (
                patch("src.skills.loader.get_skills_root_path", return_value=skills_root),
                patch("src.gateway.routers.skills._validate_skill_frontmatter", return_value=(True, "OK", "my-skill")),
            ):
                result = client.install_skill(archive_path)

            # Installed skills land under <skills_root>/custom/<name>.
            assert result["success"] is True
            assert result["skill_name"] == "my-skill"
            assert (skills_root / "custom" / "my-skill").exists()

    def test_install_skill_not_found(self, client):
        with pytest.raises(FileNotFoundError):
            client.install_skill("/nonexistent/path.skill")

    def test_install_skill_bad_extension(self, client):
        # Only the ".skill" extension is accepted; ".zip" must be rejected.
        with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as f:
            tmp_path = Path(f.name)
        try:
            with pytest.raises(ValueError, match=".skill extension"):
                client.install_skill(tmp_path)
        finally:
            tmp_path.unlink()


# ---------------------------------------------------------------------------
# Memory management
# ---------------------------------------------------------------------------

class TestMemoryManagement:
    """Memory reload, configuration, and combined status reporting."""

    def test_reload_memory(self, client):
        data = {"version": "1.0", "facts": []}
        with patch("src.agents.memory.updater.reload_memory_data", return_value=data):
            result = client.reload_memory()
        assert result == data

    def test_get_memory_config(self, client):
        config = MagicMock()
        config.enabled = True
        config.storage_path = ".deer-flow/memory.json"
        config.debounce_seconds = 30
        config.max_facts = 100
        config.fact_confidence_threshold = 0.7
        config.injection_enabled = True
        config.max_injection_tokens = 2000

        with patch("src.config.memory_config.get_memory_config", return_value=config):
            result = client.get_memory_config()

        assert result["enabled"] is True
        assert result["max_facts"] == 100

    def test_get_memory_status(self, client):
        config = MagicMock()
        config.enabled = True
        config.storage_path = ".deer-flow/memory.json"
        config.debounce_seconds = 30
        config.max_facts = 100
        config.fact_confidence_threshold = 0.7
        config.injection_enabled = True
        config.max_injection_tokens = 2000

        data = {"version": "1.0", "facts": []}

        with (
            patch("src.config.memory_config.get_memory_config", return_value=config),
            patch("src.agents.memory.updater.get_memory_data", return_value=data),
        ):
            result = client.get_memory_status()

        # Status combines config and data under separate keys.
        assert "config" in result
        assert "data" in result


# ---------------------------------------------------------------------------
# Uploads
# ---------------------------------------------------------------------------

class TestUploads:
    """Per-thread file uploads: add, list, delete, traversal guard."""

    def test_upload_files(self, client):
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)

            # Create a source file
            src_file = tmp_path / "test.txt"
            src_file.write_text("hello")

            uploads_dir = tmp_path / "uploads"
            uploads_dir.mkdir()

            with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
                result = client.upload_files("thread-1", [src_file])

            assert len(result) == 1
            assert result[0]["filename"] == "test.txt"
            assert (uploads_dir / "test.txt").exists()

    def test_upload_files_not_found(self, client):
        with pytest.raises(FileNotFoundError):
            client.upload_files("thread-1", ["/nonexistent/file.txt"])

    def test_list_uploads(self, client):
        with tempfile.TemporaryDirectory() as tmp:
            uploads_dir = Path(tmp)
            (uploads_dir / "a.txt").write_text("a")
            (uploads_dir / "b.txt").write_text("bb")

            with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
                result = client.list_uploads("thread-1")

            assert len(result) == 2
            names = {f["filename"] for f in result}
            assert names == {"a.txt", "b.txt"}

    def test_delete_upload(self, client):
        with tempfile.TemporaryDirectory() as tmp:
            uploads_dir = Path(tmp)
            (uploads_dir / "delete-me.txt").write_text("gone")

            with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
                client.delete_upload("thread-1", "delete-me.txt")

            assert not (uploads_dir / "delete-me.txt").exists()

    def test_delete_upload_not_found(self, client):
        with tempfile.TemporaryDirectory() as tmp:
            with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=Path(tmp)):
                with pytest.raises(FileNotFoundError):
                    client.delete_upload("thread-1", "nope.txt")

    def test_delete_upload_path_traversal(self, client):
        # "../" escapes must be rejected before any filesystem access.
        with tempfile.TemporaryDirectory() as tmp:
            uploads_dir = Path(tmp)
            with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
                with pytest.raises(PermissionError):
                    client.delete_upload("thread-1", "../../etc/passwd")


# ---------------------------------------------------------------------------
# Artifacts
# ---------------------------------------------------------------------------

class TestArtifacts:
    """Artifact retrieval from the sandbox user-data tree."""

    def test_get_artifact(self, client):
        with tempfile.TemporaryDirectory() as tmp:
            user_data_dir = Path(tmp) / "user-data"
            outputs = user_data_dir / "outputs"
            outputs.mkdir(parents=True)
            (outputs / "result.txt").write_text("artifact content")

            mock_paths = MagicMock()
            mock_paths.sandbox_user_data_dir.return_value = user_data_dir

            with patch("src.client.get_paths", return_value=mock_paths):
                content, mime = client.get_artifact("t1", "mnt/user-data/outputs/result.txt")

            assert content == b"artifact content"
            assert "text" in mime

    def test_get_artifact_not_found(self, client):
        with tempfile.TemporaryDirectory() as tmp:
            user_data_dir = Path(tmp) / "user-data"
            user_data_dir.mkdir()

            mock_paths = MagicMock()
            mock_paths.sandbox_user_data_dir.return_value = user_data_dir

            with patch("src.client.get_paths", return_value=mock_paths):
                with pytest.raises(FileNotFoundError):
                    client.get_artifact("t1", "mnt/user-data/outputs/nope.txt")

    def test_get_artifact_bad_prefix(self, client):
        # Paths must start with the sandbox-visible "mnt/user-data/" prefix.
        with pytest.raises(ValueError, match="must start with"):
            client.get_artifact("t1", "bad/path/file.txt")

    def test_get_artifact_path_traversal(self, client):
        with tempfile.TemporaryDirectory() as tmp:
            user_data_dir = Path(tmp) / "user-data"
            user_data_dir.mkdir()

            mock_paths = MagicMock()
            mock_paths.sandbox_user_data_dir.return_value = user_data_dir

            with patch("src.client.get_paths", return_value=mock_paths):
                with pytest.raises(PermissionError):
                    client.get_artifact("t1", "mnt/user-data/../../../etc/passwd")


# ===========================================================================
# Scenario-based integration tests
# ===========================================================================
# These tests simulate realistic user workflows end-to-end, exercising
# multiple methods in sequence to verify they compose correctly.


class TestScenarioMultiTurnConversation:
    """Scenario: User has a multi-turn conversation within a single thread."""

    def test_two_turn_conversation(self, client):
        """Two sequential chat() calls on the same thread_id produce
        independent results (without checkpointer, each call is stateless)."""
        ai1 = AIMessage(content="I'm a helpful assistant.", id="ai-1")
        ai2 = AIMessage(content="Python is great!", id="ai-2")

        # side_effect supplies a fresh iterator per stream() call, one per turn.
        agent = MagicMock()
        agent.stream.side_effect = [
            iter([{"messages": [HumanMessage(content="who are you?", id="h-1"), ai1]}]),
            iter([{"messages": [HumanMessage(content="what language?", id="h-2"), ai2]}]),
        ]

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            r1 = client.chat("who are you?", thread_id="thread-multi")
            r2 = client.chat("what language?", thread_id="thread-multi")

        assert r1 == "I'm a helpful assistant."
        assert r2 == "Python is great!"
        assert agent.stream.call_count == 2

    def test_stream_collects_all_event_types_across_turns(self, client):
        """A full turn with tool_call → tool_result → message → title → done."""
        ai_tc = AIMessage(content="", id="ai-1", tool_calls=[
            {"name": "web_search", "args": {"query": "LangGraph"}, "id": "tc-1"},
        ])
        tool_r = ToolMessage(content="LangGraph is a framework...", id="tm-1", tool_call_id="tc-1", name="web_search")
        ai_final = AIMessage(content="LangGraph is a framework for building agents.", id="ai-2")

        chunks = [
            {"messages": [HumanMessage(content="search", id="h-1"), ai_tc]},
            {"messages": [HumanMessage(content="search", id="h-1"), ai_tc, tool_r]},
            {"messages": [HumanMessage(content="search", id="h-1"), ai_tc, tool_r, ai_final], "title": "LangGraph Search"},
        ]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("search", thread_id="t-full"))

        # Exact ordering is asserted, not just membership.
        types = [e.type for e in events]
        assert types == ["tool_call", "tool_result", "message", "title", "done"]

        # Verify event data integrity
        tc_event = events[0]
        assert tc_event.data["name"] == "web_search"
        assert tc_event.data["args"] == {"query": "LangGraph"}

        tr_event = events[1]
        assert tr_event.data["tool_call_id"] == "tc-1"
        assert "LangGraph" in tr_event.data["content"]

        msg_event = events[2]
        assert "framework" in msg_event.data["content"]

        title_event = events[3]
        assert title_event.data["title"] == "LangGraph Search"


class TestScenarioToolChain:
    """Scenario: Agent chains multiple tool calls in sequence."""

    def test_multi_tool_chain(self, client):
        """Agent calls bash → reads output → calls write_file → responds."""
        ai_bash = AIMessage(content="", id="ai-1", tool_calls=[
            {"name": "bash", "args": {"cmd": "ls /mnt/user-data/workspace"}, "id": "tc-1"},
        ])
        bash_result = ToolMessage(content="README.md\nsrc/", id="tm-1", tool_call_id="tc-1", name="bash")
        ai_write = AIMessage(content="", id="ai-2", tool_calls=[
            {"name": "write_file", "args": {"path": "/mnt/user-data/outputs/listing.txt", "content": "README.md\nsrc/"}, "id": "tc-2"},
        ])
        write_result = ToolMessage(content="File written successfully.", id="tm-2", tool_call_id="tc-2", name="write_file")
        ai_final = AIMessage(content="I listed the workspace and saved the output.", id="ai-3")

        chunks = [
            {"messages": [HumanMessage(content="list and save", id="h-1"), ai_bash]},
            {"messages": [HumanMessage(content="list and save", id="h-1"), ai_bash, bash_result]},
            {"messages": [HumanMessage(content="list and save", id="h-1"), ai_bash, bash_result, ai_write]},
            {"messages": [HumanMessage(content="list and save", id="h-1"), ai_bash, bash_result, ai_write, write_result]},
            {"messages": [HumanMessage(content="list and save", id="h-1"), ai_bash, bash_result, ai_write, write_result, ai_final]},
        ]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("list and save", thread_id="t-chain"))

        tool_calls = [e for e in events if e.type == "tool_call"]
        tool_results = [e for e in events if e.type == "tool_result"]
        messages = [e for e in events if e.type == "message"]

        # Empty-content AI messages with tool_calls produce no "message" event.
        assert len(tool_calls) == 2
        assert tool_calls[0].data["name"] == "bash"
        assert tool_calls[1].data["name"] == "write_file"
        assert len(tool_results) == 2
        assert len(messages) == 1
        assert events[-1].type == "done"


class TestScenarioFileLifecycle:
    """Scenario: Upload files → list them → use in chat → download artifact."""

    def test_upload_list_delete_lifecycle(self, client):
        """Upload → list → verify → delete → list again."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            uploads_dir = tmp_path / "uploads"
            uploads_dir.mkdir()

            # Create source files
            (tmp_path / "report.txt").write_text("quarterly report data")
            (tmp_path / "data.csv").write_text("a,b,c\n1,2,3")

            with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
                # Step 1: Upload
                uploaded = client.upload_files("t-lifecycle", [
                    tmp_path / "report.txt",
                    tmp_path / "data.csv",
                ])
                assert len(uploaded) == 2
                assert {f["filename"] for f in uploaded} == {"report.txt", "data.csv"}

                # Step 2: List
                files = client.list_uploads("t-lifecycle")
                assert len(files) == 2
                assert all("virtual_path" in f for f in files)

                # Step 3: Delete one
                client.delete_upload("t-lifecycle", "report.txt")

                # Step 4: Verify deletion
                files = client.list_uploads("t-lifecycle")
                assert len(files) == 1
                assert files[0]["filename"] == "data.csv"

    def test_upload_then_read_artifact(self, client):
        """Upload a file, simulate agent producing artifact, read it back."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            uploads_dir = tmp_path / "uploads"
            uploads_dir.mkdir()
            user_data_dir = tmp_path / "user-data"
            outputs_dir = user_data_dir / "outputs"
            outputs_dir.mkdir(parents=True)

            # Upload phase
            src_file = tmp_path / "input.txt"
            src_file.write_text("raw data to process")

            with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir):
                uploaded = client.upload_files("t-artifact", [src_file])
                assert len(uploaded) == 1

            # Simulate agent writing an artifact
            (outputs_dir / "analysis.json").write_text('{"result": "processed"}')

            # Retrieve artifact
            mock_paths = MagicMock()
            mock_paths.sandbox_user_data_dir.return_value = user_data_dir

            with patch("src.client.get_paths", return_value=mock_paths):
                content, mime = client.get_artifact("t-artifact", "mnt/user-data/outputs/analysis.json")

            assert json.loads(content) == {"result": "processed"}
            assert "json" in mime


class TestScenarioConfigManagement:
    """Scenario: Query and update configuration through a management session."""

    def test_model_and_skill_discovery(self, client):
        """List models → get specific model → list skills → get specific skill."""
        # List models
        models = client.list_models()
        assert len(models) >= 1
        model_name = models[0]["name"]

        # Get specific model
        model_cfg = MagicMock()
        model_cfg.model_dump.return_value = {"name": model_name, "use": "langchain_openai:ChatOpenAI"}
        client._app_config.get_model_config.return_value = model_cfg
        detail = client.get_model(model_name)
        assert detail["name"] == model_name

        # List skills
        skill = MagicMock()
        skill.name = "web-search"
        skill.description = "Search the web"
        skill.category = "public"
        skill.enabled = True

        with patch("src.skills.loader.load_skills", return_value=[skill]):
            skills = client.list_skills()
            assert len(skills) == 1

        # Get specific skill
        with patch("src.skills.loader.load_skills", return_value=[skill]):
            detail = client.get_skill("web-search")
            assert detail is not None
            assert detail["enabled"] is True

    def test_mcp_update_then_skill_toggle(self, client):
        """Update MCP config → toggle skill → verify both invalidate agent."""
        with tempfile.TemporaryDirectory() as tmp:
            config_file = Path(tmp) / "extensions_config.json"
            config_file.write_text("{}")

            # --- MCP update ---
            current_config = MagicMock()
            current_config.skills = {}

            reloaded_server = MagicMock()
            reloaded_server.model_dump.return_value = {"enabled": True, "type": "sse"}
            reloaded_config = MagicMock()
            reloaded_config.mcp_servers = {"my-mcp": reloaded_server}

            client._agent = MagicMock()  # Simulate existing agent
            with (
                patch("src.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
                patch("src.client.get_extensions_config", return_value=current_config),
                patch("src.client.reload_extensions_config", return_value=reloaded_config),
            ):
                mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}})
                assert "my-mcp" in mcp_result
                assert client._agent is None  # Agent invalidated

            # --- Skill toggle ---
            skill = MagicMock()
            skill.name = "code-gen"
            skill.description = "Generate code"
            skill.license = "MIT"
            skill.category = "custom"
            skill.enabled = True

            toggled = MagicMock()
            toggled.name = "code-gen"
            toggled.description = "Generate code"
            toggled.license = "MIT"
            toggled.category = "custom"
            toggled.enabled = False

            ext_config = MagicMock()
            ext_config.mcp_servers = {}
            ext_config.skills = {}

            client._agent = MagicMock()  # Simulate re-created agent
            with (
                patch("src.skills.loader.load_skills", side_effect=[[skill], [toggled]]),
                patch("src.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
                patch("src.client.get_extensions_config", return_value=ext_config),
                patch("src.client.reload_extensions_config"),
            ):
                skill_result = client.update_skill("code-gen", enabled=False)
                assert skill_result["enabled"] is False
                assert client._agent is None  # Agent invalidated again


class TestScenarioAgentRecreation:
    """Scenario: Config changes trigger agent recreation at the right times."""

    def test_different_model_triggers_rebuild(self, client):
        """Switching model_name between calls forces agent rebuild."""
        # NOTE(review): mock_agent_1/mock_agent_2 are never used below —
        # candidates for removal.
        mock_agent_1 = MagicMock(name="agent-v1")
        mock_agent_2 = MagicMock(name="agent-v2")
        agents_created = []

        def fake_create_agent(**kwargs):
            agent = MagicMock()
            agents_created.append(agent)
            return agent

        config_a = client._get_runnable_config("t1", model_name="gpt-4")
        config_b = client._get_runnable_config("t1", model_name="claude-3")

        with (
            patch("src.client.create_chat_model"),
            patch("src.client.create_agent", side_effect=fake_create_agent),
            patch("src.client._build_middlewares", return_value=[]),
            patch("src.client.apply_prompt_template", return_value="prompt"),
            patch.object(client, "_get_tools", return_value=[]),
        ):
            client._ensure_agent(config_a)
            first_agent = client._agent

            client._ensure_agent(config_b)
            second_agent = client._agent

        assert len(agents_created) == 2
        assert first_agent is not second_agent

    def test_same_config_reuses_agent(self, client):
        """Repeated calls with identical config do not rebuild."""
        agents_created = []

        def fake_create_agent(**kwargs):
            agent = MagicMock()
            agents_created.append(agent)
            return agent

        config = client._get_runnable_config("t1", model_name="gpt-4")

        with (
            patch("src.client.create_chat_model"),
            patch("src.client.create_agent", side_effect=fake_create_agent),
            patch("src.client._build_middlewares", return_value=[]),
            patch("src.client.apply_prompt_template", return_value="prompt"),
            patch.object(client, "_get_tools", return_value=[]),
        ):
            client._ensure_agent(config)
            client._ensure_agent(config)
            client._ensure_agent(config)

        assert len(agents_created) == 1

    def test_reset_agent_forces_rebuild(self, client):
        """reset_agent() clears cache, next call rebuilds."""
        agents_created = []

        def fake_create_agent(**kwargs):
            agent = MagicMock()
            agents_created.append(agent)
            return agent

        config = client._get_runnable_config("t1")

        with (
            patch("src.client.create_chat_model"),
            patch("src.client.create_agent", side_effect=fake_create_agent),
            patch("src.client._build_middlewares", return_value=[]),
            patch("src.client.apply_prompt_template", return_value="prompt"),
            patch.object(client, "_get_tools", return_value=[]),
        ):
            client._ensure_agent(config)
            client.reset_agent()
            client._ensure_agent(config)

        assert len(agents_created) == 2

    def test_per_call_override_triggers_rebuild(self, client):
        """stream() with model_name override creates a different agent config."""
        ai = AIMessage(content="ok", id="ai-1")
        agent = _make_agent_mock([{"messages": [ai]}])

        agents_created = []

        # Capture the config key _ensure_agent would see for each stream() call.
        def fake_ensure(config):
            key = tuple(config.get("configurable", {}).get(k) for k in ["model_name", "thinking_enabled", "is_plan_mode", "subagent_enabled"])
            agents_created.append(key)
            client._agent = agent

        with patch.object(client, "_ensure_agent", side_effect=fake_ensure):
            list(client.stream("hi", thread_id="t1"))
            list(client.stream("hi", thread_id="t1", model_name="other-model"))

        # Two different config keys should have been created
        assert len(agents_created) == 2
        assert agents_created[0] != agents_created[1]


class TestScenarioThreadIsolation:
    """Scenario: Operations on different threads don't interfere."""

    def test_uploads_isolated_per_thread(self, client):
        """Files uploaded to thread-A are not visible in thread-B."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            uploads_a = tmp_path / "thread-a" / "uploads"
            uploads_b = tmp_path / "thread-b" / "uploads"
            uploads_a.mkdir(parents=True)
            uploads_b.mkdir(parents=True)

            src_file = tmp_path / "secret.txt"
            src_file.write_text("thread-a only")

            # Route each thread_id to its own uploads dir.
            def get_dir(thread_id):
                return uploads_a if thread_id == "thread-a" else uploads_b

            with patch.object(DeerFlowClient, "_get_uploads_dir", side_effect=get_dir):
                client.upload_files("thread-a", [src_file])

                files_a = client.list_uploads("thread-a")
                files_b = client.list_uploads("thread-b")

            assert len(files_a) == 1
            assert len(files_b) == 0

    def test_artifacts_isolated_per_thread(self, client):
        """Artifacts in thread-A are not accessible from thread-B."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)

            data_a = tmp_path / "thread-a"
            data_b = tmp_path / "thread-b"
            (data_a / "outputs").mkdir(parents=True)
            (data_b / "outputs").mkdir(parents=True)
            (data_a / "outputs" / "result.txt").write_text("thread-a artifact")

            mock_paths = MagicMock()
            mock_paths.sandbox_user_data_dir.side_effect = lambda tid: data_a if tid == "thread-a" else data_b

            with patch("src.client.get_paths", return_value=mock_paths):
                content, _ = client.get_artifact("thread-a", "mnt/user-data/outputs/result.txt")
                assert content == b"thread-a artifact"

                with pytest.raises(FileNotFoundError):
                    client.get_artifact("thread-b", "mnt/user-data/outputs/result.txt")


class TestScenarioMemoryWorkflow:
    """Scenario: Memory query → reload → status check."""

    def test_memory_full_lifecycle(self, client):
        """get_memory → reload → get_status covers the full memory API."""
        initial_data = {"version": "1.0", "facts": [{"id": "f1", "content": "User likes Python"}]}
        updated_data = {"version": "1.0", "facts": [
            {"id": "f1", "content": "User likes Python"},
            {"id": "f2", "content": "User prefers dark mode"},
        ]}

        config = MagicMock()
        config.enabled = True
        config.storage_path = ".deer-flow/memory.json"
        config.debounce_seconds = 30
        config.max_facts = 100
        config.fact_confidence_threshold = 0.7
        config.injection_enabled = True
        config.max_injection_tokens = 2000

        with patch("src.agents.memory.updater.get_memory_data", return_value=initial_data):
            mem = client.get_memory()
            assert len(mem["facts"]) == 1

        with patch("src.agents.memory.updater.reload_memory_data", return_value=updated_data):
            refreshed = client.reload_memory()
            assert len(refreshed["facts"]) == 2

        with (
            patch("src.config.memory_config.get_memory_config", return_value=config),
            patch("src.agents.memory.updater.get_memory_data", return_value=updated_data),
        ):
            status = client.get_memory_status()
            assert status["config"]["enabled"] is True
            assert len(status["data"]["facts"]) == 2


class TestScenarioSkillInstallAndUse:
    """Scenario: Install a skill → verify it appears → toggle it."""

    def test_install_then_toggle(self, client):
        """Install .skill archive → list to verify → disable → verify disabled."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)

            # Create .skill archive
            skill_src = tmp_path / "my-analyzer"
            skill_src.mkdir()
            (skill_src / "SKILL.md").write_text(
                "---\nname: my-analyzer\ndescription: Analyze code\nlicense: MIT\n---\nAnalysis skill"
            )
            archive = tmp_path / "my-analyzer.skill"
            with zipfile.ZipFile(archive, "w") as zf:
                zf.write(skill_src / "SKILL.md", "my-analyzer/SKILL.md")

            skills_root = tmp_path / "skills"
            (skills_root / "custom").mkdir(parents=True)

            # Step 1: Install
            with (
                patch("src.skills.loader.get_skills_root_path", return_value=skills_root),
                patch("src.gateway.routers.skills._validate_skill_frontmatter", return_value=(True, "OK", "my-analyzer")),
            ):
                result = client.install_skill(archive)
                assert result["success"] is True
                assert (skills_root / "custom" / "my-analyzer" / "SKILL.md").exists()

            # Step 2: List and find it
            installed_skill = MagicMock()
            installed_skill.name = "my-analyzer"
            installed_skill.description = "Analyze code"
            installed_skill.category = "custom"
            installed_skill.enabled = True

            with patch("src.skills.loader.load_skills", return_value=[installed_skill]):
                skills = client.list_skills()
                assert any(s["name"] == "my-analyzer" for s in skills)

            # Step 3: Disable it
            disabled_skill = MagicMock()
            disabled_skill.name = "my-analyzer"
            disabled_skill.description = "Analyze code"
            disabled_skill.license = "MIT"
            disabled_skill.category = "custom"
            disabled_skill.enabled = False

            ext_config = MagicMock()
            ext_config.mcp_servers = {}
            ext_config.skills = {}

            config_file = tmp_path / "extensions_config.json"
            config_file.write_text("{}")

            with (
                patch("src.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
                patch("src.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
                patch("src.client.get_extensions_config", return_value=ext_config),
                patch("src.client.reload_extensions_config"),
            ):
                toggled = client.update_skill("my-analyzer", enabled=False)
                assert toggled["enabled"] is False


class TestScenarioEdgeCases:
    """Scenario: Edge cases and error boundaries in realistic workflows."""

    def test_empty_stream_response(self, client):
        """Agent produces no messages — only done event."""
        agent = _make_agent_mock([{"messages": []}])

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            events = list(client.stream("hi", thread_id="t-empty"))

        assert len(events) == 1
        assert events[0].type == "done"

    def test_chat_on_empty_response(self, client):
        """chat() returns empty string for no-message response."""
        agent = _make_agent_mock([{"messages": []}])

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
            result = client.chat("hi", thread_id="t-empty-chat")

        assert result == ""

    def test_multiple_title_changes(self, client):
        """Only distinct title changes produce events."""
        ai = AIMessage(content="ok", id="ai-1")
        chunks = [
            {"messages": [ai], "title": "First Title"},
            {"messages": [], "title": "First Title"},  # same — should NOT emit
            {"messages": [], "title": "Second Title"},  # different — should emit
        ]
        agent = _make_agent_mock(chunks)

        with (
            patch.object(client, "_ensure_agent"),
            patch.object(client, "_agent", agent),
        ):
events = list(client.stream("hi", thread_id="t-titles")) + + title_events = [e for e in events if e.type == "title"] + assert len(title_events) == 2 + assert title_events[0].data["title"] == "First Title" + assert title_events[1].data["title"] == "Second Title" + + def test_concurrent_tool_calls_in_single_message(self, client): + """Agent produces multiple tool_calls in one AIMessage.""" + ai = AIMessage(content="", id="ai-1", tool_calls=[ + {"name": "web_search", "args": {"q": "a"}, "id": "tc-1"}, + {"name": "web_search", "args": {"q": "b"}, "id": "tc-2"}, + {"name": "bash", "args": {"cmd": "echo hi"}, "id": "tc-3"}, + ]) + chunks = [{"messages": [ai]}] + agent = _make_agent_mock(chunks) + + with ( + patch.object(client, "_ensure_agent"), + patch.object(client, "_agent", agent), + ): + events = list(client.stream("do things", thread_id="t-parallel")) + + tc_events = [e for e in events if e.type == "tool_call"] + assert len(tc_events) == 3 + assert {e.data["id"] for e in tc_events} == {"tc-1", "tc-2", "tc-3"} + + def test_upload_convertible_file_conversion_failure(self, client): + """Upload a .pdf file where conversion fails — file still uploaded, no markdown.""" + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + uploads_dir = tmp_path / "uploads" + uploads_dir.mkdir() + + pdf_file = tmp_path / "doc.pdf" + pdf_file.write_bytes(b"%PDF-1.4 fake content") + + with ( + patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir), + patch("src.gateway.routers.uploads.CONVERTIBLE_EXTENSIONS", {".pdf"}), + patch("src.gateway.routers.uploads.convert_file_to_markdown", side_effect=Exception("conversion failed")), + ): + results = client.upload_files("t-pdf-fail", [pdf_file]) + + assert len(results) == 1 + assert results[0]["filename"] == "doc.pdf" + assert "markdown_file" not in results[0] # Conversion failed gracefully + assert (uploads_dir / "doc.pdf").exists() # File still uploaded diff --git a/backend/tests/test_client_live.py 
b/backend/tests/test_client_live.py new file mode 100644 index 0000000..b1d8d05 --- /dev/null +++ b/backend/tests/test_client_live.py @@ -0,0 +1,312 @@ +"""Live integration tests for DeerFlowClient with real API. + +These tests require a working config.yaml with valid API credentials. +They are skipped in CI and must be run explicitly: + + PYTHONPATH=. uv run pytest tests/test_client_live.py -v -s +""" + +import json +import os +import tempfile +from pathlib import Path + +import pytest + +# Skip entire module in CI or when no config.yaml exists +_skip_reason = None +if os.environ.get("CI"): + _skip_reason = "Live tests skipped in CI" +elif not Path(__file__).resolve().parents[2].joinpath("config.yaml").exists(): + _skip_reason = "No config.yaml found — live tests require valid API credentials" + +if _skip_reason: + pytest.skip(_skip_reason, allow_module_level=True) + +from src.client import DeerFlowClient, StreamEvent + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +def client(): + """Create a real DeerFlowClient (no mocks).""" + return DeerFlowClient(thinking_enabled=False) + + +@pytest.fixture +def thread_tmp(tmp_path): + """Provide a unique thread_id + tmp directory for file operations.""" + import uuid + tid = f"live-test-{uuid.uuid4().hex[:8]}" + return tid, tmp_path + + +# =========================================================================== +# Scenario 1: Basic chat — model responds coherently +# =========================================================================== + +class TestLiveBasicChat: + def test_chat_returns_nonempty_string(self, client): + """chat() returns a non-empty response from the real model.""" + response = client.chat("Reply with exactly: HELLO") + assert isinstance(response, str) + assert len(response) > 0 + print(f" chat response: {response}") + + def 
test_chat_follows_instruction(self, client): + """Model can follow a simple instruction.""" + response = client.chat("What is 7 * 8? Reply with just the number.") + assert "56" in response + print(f" math response: {response}") + + +# =========================================================================== +# Scenario 2: Streaming — events arrive in correct order +# =========================================================================== + +class TestLiveStreaming: + def test_stream_yields_message_and_done(self, client): + """stream() produces at least one message event and ends with done.""" + events = list(client.stream("Say hi in one word.")) + + types = [e.type for e in events] + assert "message" in types, f"Expected 'message' event, got: {types}" + assert types[-1] == "done" + + for e in events: + assert isinstance(e, StreamEvent) + print(f" [{e.type}] {e.data}") + + def test_stream_message_content_nonempty(self, client): + """Streamed message events contain non-empty content.""" + messages = [ + e for e in client.stream("What color is the sky? One word.") + if e.type == "message" + ] + assert len(messages) >= 1 + for m in messages: + assert len(m.data.get("content", "")) > 0 + + +# =========================================================================== +# Scenario 3: Tool use — agent calls a tool and returns result +# =========================================================================== + +class TestLiveToolUse: + def test_agent_uses_bash_tool(self, client): + """Agent uses bash tool when asked to run a command.""" + events = list(client.stream( + "Use the bash tool to run: echo 'LIVE_TEST_OK'. " + "Then tell me the output." 
+ )) + + types = [e.type for e in events] + print(f" event types: {types}") + for e in events: + print(f" [{e.type}] {e.data}") + + # Should have tool_call + tool_result + message + assert "tool_call" in types, f"Expected tool_call, got: {types}" + assert "tool_result" in types, f"Expected tool_result, got: {types}" + assert "message" in types + + tc = next(e for e in events if e.type == "tool_call") + assert tc.data["name"] == "bash" + + tr = next(e for e in events if e.type == "tool_result") + assert "LIVE_TEST_OK" in tr.data["content"] + + def test_agent_uses_ls_tool(self, client): + """Agent uses ls tool to list a directory.""" + events = list(client.stream( + "Use the ls tool to list the contents of /mnt/user-data/workspace. " + "Just report what you see." + )) + + types = [e.type for e in events] + print(f" event types: {types}") + + assert "tool_call" in types + tc = next(e for e in events if e.type == "tool_call") + assert tc.data["name"] == "ls" + + +# =========================================================================== +# Scenario 4: Multi-tool chain — agent chains tools in sequence +# =========================================================================== + +class TestLiveMultiToolChain: + def test_write_then_read(self, client): + """Agent writes a file, then reads it back.""" + events = list(client.stream( + "Step 1: Use write_file to write 'integration_test_content' to " + "/mnt/user-data/outputs/live_test.txt. " + "Step 2: Use read_file to read that file back. " + "Step 3: Tell me the content you read." 
+ )) + + types = [e.type for e in events] + print(f" event types: {types}") + for e in events: + print(f" [{e.type}] {e.data}") + + tool_calls = [e for e in events if e.type == "tool_call"] + tool_names = [tc.data["name"] for tc in tool_calls] + + assert "write_file" in tool_names, f"Expected write_file, got: {tool_names}" + assert "read_file" in tool_names, f"Expected read_file, got: {tool_names}" + + # Final message should mention the content + messages = [e for e in events if e.type == "message"] + final_text = messages[-1].data["content"] if messages else "" + assert "integration_test_content" in final_text.lower() or any( + "integration_test_content" in e.data.get("content", "") + for e in events if e.type == "tool_result" + ) + + +# =========================================================================== +# Scenario 5: File upload lifecycle with real filesystem +# =========================================================================== + +class TestLiveFileUpload: + def test_upload_list_delete(self, client, thread_tmp): + """Upload → list → delete → verify deletion.""" + thread_id, tmp_path = thread_tmp + + # Create test files + f1 = tmp_path / "test_upload_a.txt" + f1.write_text("content A") + f2 = tmp_path / "test_upload_b.txt" + f2.write_text("content B") + + # Upload + results = client.upload_files(thread_id, [f1, f2]) + assert len(results) == 2 + filenames = {r["filename"] for r in results} + assert filenames == {"test_upload_a.txt", "test_upload_b.txt"} + for r in results: + assert r["size"] > 0 + assert r["virtual_path"].startswith("/mnt/user-data/uploads/") + print(f" uploaded: {filenames}") + + # List + listed = client.list_uploads(thread_id) + assert len(listed) == 2 + print(f" listed: {[f['filename'] for f in listed]}") + + # Delete one + client.delete_upload(thread_id, "test_upload_a.txt") + remaining = client.list_uploads(thread_id) + assert len(remaining) == 1 + assert remaining[0]["filename"] == "test_upload_b.txt" + print(f" after 
delete: {[f['filename'] for f in remaining]}") + + # Delete the other + client.delete_upload(thread_id, "test_upload_b.txt") + assert client.list_uploads(thread_id) == [] + + def test_upload_nonexistent_file_raises(self, client): + with pytest.raises(FileNotFoundError): + client.upload_files("t-fail", ["/nonexistent/path/file.txt"]) + + +# =========================================================================== +# Scenario 6: Configuration query — real config loading +# =========================================================================== + +class TestLiveConfigQueries: + def test_list_models_returns_ark(self, client): + """list_models() returns the configured ARK model.""" + models = client.list_models() + assert len(models) >= 1 + names = [m["name"] for m in models] + assert "ark-model" in names + print(f" models: {names}") + + def test_get_model_found(self, client): + """get_model() returns details for existing model.""" + model = client.get_model("ark-model") + assert model is not None + assert model["name"] == "ark-model" + print(f" model detail: {model}") + + def test_get_model_not_found(self, client): + assert client.get_model("nonexistent-model-xyz") is None + + def test_list_skills(self, client): + """list_skills() runs without error.""" + skills = client.list_skills() + assert isinstance(skills, list) + print(f" skills count: {len(skills)}") + for s in skills[:3]: + print(f" - {s['name']}: {s['enabled']}") + + +# =========================================================================== +# Scenario 7: Artifact read after agent writes +# =========================================================================== + +class TestLiveArtifact: + def test_get_artifact_after_write(self, client): + """Agent writes a file → client reads it back via get_artifact().""" + import uuid + thread_id = f"live-artifact-{uuid.uuid4().hex[:8]}" + + # Ask agent to write a file + events = list(client.stream( + "Use write_file to create 
/mnt/user-data/outputs/artifact_test.json " + "with content: {\"status\": \"ok\", \"source\": \"live_test\"}", + thread_id=thread_id, + )) + + # Verify write happened + tool_calls = [e for e in events if e.type == "tool_call"] + assert any(tc.data["name"] == "write_file" for tc in tool_calls) + + # Read artifact + content, mime = client.get_artifact(thread_id, "mnt/user-data/outputs/artifact_test.json") + data = json.loads(content) + assert data["status"] == "ok" + assert data["source"] == "live_test" + assert "json" in mime + print(f" artifact: {data}, mime: {mime}") + + def test_get_artifact_not_found(self, client): + with pytest.raises(FileNotFoundError): + client.get_artifact("nonexistent-thread", "mnt/user-data/outputs/nope.txt") + + +# =========================================================================== +# Scenario 8: Per-call overrides +# =========================================================================== + +class TestLiveOverrides: + def test_thinking_disabled_still_works(self, client): + """Explicit thinking_enabled=False override produces a response.""" + response = client.chat( + "Say OK.", thinking_enabled=False, + ) + assert len(response) > 0 + print(f" response: {response}") + + +# =========================================================================== +# Scenario 9: Error resilience +# =========================================================================== + +class TestLiveErrorResilience: + def test_delete_nonexistent_upload(self, client): + with pytest.raises(FileNotFoundError): + client.delete_upload("nonexistent-thread", "ghost.txt") + + def test_bad_artifact_path(self, client): + with pytest.raises(ValueError): + client.get_artifact("t", "invalid/path") + + def test_path_traversal_blocked(self, client): + with pytest.raises(PermissionError): + client.delete_upload("t", "../../etc/passwd")