diff --git a/README.md b/README.md index 91a6cd3..f9707d3 100644 --- a/README.md +++ b/README.md @@ -238,7 +238,7 @@ DeerFlow is model-agnostic — it works with any LLM that implements the OpenAI- ## Embedded Python Client -DeerFlow can be used as an embedded Python library without running the full HTTP services. The `DeerFlowClient` provides direct in-process access to all agent and Gateway capabilities: +DeerFlow can be used as an embedded Python library without running the full HTTP services. The `DeerFlowClient` provides direct in-process access to all agent and Gateway capabilities, returning the same response schemas as the HTTP Gateway API: ```python from src.client import DeerFlowClient @@ -248,18 +248,19 @@ client = DeerFlowClient() # Chat response = client.chat("Analyze this paper for me", thread_id="my-thread") -# Streaming +# Streaming (LangGraph SSE protocol: values, messages-tuple, end) for event in client.stream("hello"): - print(event.type, event.data) + if event.type == "messages-tuple" and event.data.get("type") == "ai": + print(event.data["content"]) -# Configuration & management -print(client.list_models()) -print(client.list_skills()) +# Configuration & management — returns Gateway-aligned dicts +models = client.list_models() # {"models": [...]} +skills = client.list_skills() # {"skills": [...]} client.update_skill("web-search", enabled=True) -client.upload_files("thread-1", ["./report.pdf"]) +client.upload_files("thread-1", ["./report.pdf"]) # {"success": True, "files": [...]} ``` -See `backend/src/client.py` for full API documentation. +All dict-returning methods are validated against Gateway Pydantic response models in CI (`TestGatewayConformance`), ensuring the embedded client stays in sync with the HTTP API schemas. See `backend/src/client.py` for full API documentation. ## Documentation diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 1e55289..31b1728 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -294,31 +294,36 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me ### Embedded Client (`src/client.py`) -`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. +`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. All return types align with the Gateway API response schemas, so consumer code works identically in HTTP and embedded modes. **Architecture**: Imports the same `src/` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency. **Agent Conversation** (replaces LangGraph Server): - `chat(message, thread_id)` — synchronous, returns final text -- `stream(message, thread_id)` — yields `StreamEvent` (message, tool_call, tool_result, title, done) +- `stream(message, thread_id)` — yields `StreamEvent` aligned with LangGraph SSE protocol: + - `"values"` — full state snapshot (title, messages, artifacts) + - `"messages-tuple"` — per-message update (AI text, tool calls, tool results) + - `"end"` — stream finished - Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent` - Supports `checkpointer` parameter for state persistence across turns -- Invocation pattern: `agent.stream(state, config, context, stream_mode="values")` +- `reset_agent()` forces agent recreation (e.g. after memory or skill changes) **Gateway Equivalent Methods** (replaces Gateway API): -| Category | Methods | -|----------|---------| -| Models | `list_models()`, `get_model(name)` | -| MCP | `get_mcp_config()`, `update_mcp_config(servers)` | -| Skills | `list_skills()`, `get_skill(name)`, `update_skill(name, enabled)`, `install_skill(path)` | -| Memory | `get_memory()`, `reload_memory()`, `get_memory_config()`, `get_memory_status()` | -| Uploads | `upload_files(thread_id, files)`, `list_uploads(thread_id)`, `delete_upload(thread_id, filename)` | -| Artifacts | `get_artifact(thread_id, path)` → `(bytes, mime_type)` | +| Category | Methods | Return format | +|----------|---------|---------------| +| Models | `list_models()`, `get_model(name)` | `{"models": [...]}`, `{name, display_name, ...}` | +| MCP | `get_mcp_config()`, `update_mcp_config(servers)` | `{"mcp_servers": {...}}` | +| Skills | `list_skills()`, `get_skill(name)`, `update_skill(name, enabled)`, `install_skill(path)` | `{"skills": [...]}` | +| Memory | `get_memory()`, `reload_memory()`, `get_memory_config()`, `get_memory_status()` | dict | +| Uploads | `upload_files(thread_id, files)`, `list_uploads(thread_id)`, `delete_upload(thread_id, filename)` | `{"success": true, "files": [...]}`, `{"files": [...], "count": N}` | +| Artifacts | `get_artifact(thread_id, path)` → `(bytes, mime_type)` | tuple | -**Key difference from Gateway**: Upload accepts local `Path` objects instead of HTTP `UploadFile`. Artifact returns `(bytes, mime_type)` instead of HTTP Response. +**Key difference from Gateway**: Upload accepts local `Path` objects instead of HTTP `UploadFile`. Artifact returns `(bytes, mime_type)` instead of HTTP Response. `update_mcp_config()` and `update_skill()` automatically invalidate the cached agent. -**Tests**: `tests/test_client.py` (45 unit tests) +**Tests**: `tests/test_client.py` (77 unit tests including `TestGatewayConformance`), `tests/test_client_live.py` (live integration tests, requires config.yaml) + +**Gateway Conformance Tests** (`TestGatewayConformance`): Validate that every dict-returning client method conforms to the corresponding Gateway Pydantic response model. Each test parses the client output through the Gateway model — if Gateway adds a required field that the client doesn't provide, Pydantic raises `ValidationError` and CI catches the drift. Covers: `ModelsListResponse`, `ModelResponse`, `SkillsListResponse`, `SkillResponse`, `SkillInstallResponse`, `McpConfigResponse`, `UploadResponse`, `MemoryConfigResponse`, `MemoryStatusResponse`. ## Development Workflow diff --git a/backend/src/client.py b/backend/src/client.py index a859c76..6639eab 100644 --- a/backend/src/client.py +++ b/backend/src/client.py @@ -30,7 +30,7 @@ from pathlib import Path from typing import Any from langchain.agents import create_agent -from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableConfig from src.agents.lead_agent.agent import _build_middlewares @@ -38,8 +38,8 @@ from src.agents.lead_agent.prompt import apply_prompt_template from src.agents.thread_state import ThreadState from src.config.app_config import get_app_config, reload_app_config from src.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config -from src.models import create_chat_model from src.config.paths import get_paths +from src.models import create_chat_model logger = logging.getLogger(__name__) @@ -48,8 +48,13 @@ logger = logging.getLogger(__name__) class StreamEvent: """A single event from the streaming agent response. + Event types align with the LangGraph SSE protocol: + - ``"values"``: Full state snapshot (title, messages, artifacts). + - ``"messages-tuple"``: Per-message update (AI text, tool calls, tool results). + - ``"end"``: Stream finished. + Attributes: - type: Event type — "message", "tool_call", "tool_result", "title", or "done". + type: Event type. data: Event payload. Contents vary by type. """ @@ -214,6 +219,28 @@ class DeerFlowClient: return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + @staticmethod + def _serialize_message(msg) -> dict: + """Serialize a LangChain message to a plain dict for values events.""" + if isinstance(msg, AIMessage): + d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)} + if msg.tool_calls: + d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls] + return d + if isinstance(msg, ToolMessage): + return { + "type": "tool", + "content": msg.content if isinstance(msg.content, str) else str(msg.content), + "name": getattr(msg, "name", None), + "tool_call_id": getattr(msg, "tool_call_id", None), + "id": getattr(msg, "id", None), + } + if isinstance(msg, HumanMessage): + return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)} + if isinstance(msg, SystemMessage): + return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)} + return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)} + @staticmethod def _extract_text(content) -> str: """Extract plain text from AIMessage content (str or list of blocks).""" @@ -246,6 +273,10 @@ class DeerFlowClient: finishes its turn. A ``checkpointer`` must be provided at init time for multi-turn context to be preserved across calls. + Event types align with the LangGraph SSE protocol so that + consumers can switch between HTTP streaming and embedded mode + without changing their event-handling logic. + Args: message: User message text. thread_id: Thread ID for conversation context. Auto-generated if None. @@ -254,11 +285,11 @@ class DeerFlowClient: Yields: StreamEvent with one of: - - type="message" data={"content": str} - - type="tool_call" data={"name": str, "args": dict, "id": str} - - type="tool_result" data={"name": str, "content": str, "tool_call_id": str} - - type="title" data={"title": str} - - type="done" data={} + - type="values" data={"title": str|None, "messages": [...], "artifacts": [...]} + - type="messages-tuple" data={"type": "ai", "content": str, "id": str} + - type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]} + - type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str} + - type="end" data={} """ if thread_id is None: thread_id = str(uuid.uuid4()) @@ -270,7 +301,6 @@ class DeerFlowClient: context = {"thread_id": thread_id} seen_ids: set[str] = set() - last_title: str | None = None for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"): messages = chunk.get("messages", []) @@ -284,41 +314,57 @@ class DeerFlowClient: if isinstance(msg, AIMessage): if msg.tool_calls: - for tc in msg.tool_calls: - yield StreamEvent( - type="tool_call", - data={"name": tc["name"], "args": tc["args"], "id": tc.get("id")}, - ) + yield StreamEvent( + type="messages-tuple", + data={ + "type": "ai", + "content": "", + "id": msg_id, + "tool_calls": [ + {"name": tc["name"], "args": tc["args"], "id": tc.get("id")} + for tc in msg.tool_calls + ], + }, + ) text = self._extract_text(msg.content) if text: - yield StreamEvent(type="message", data={"content": text}) + yield StreamEvent( + type="messages-tuple", + data={"type": "ai", "content": text, "id": msg_id}, + ) elif isinstance(msg, ToolMessage): yield StreamEvent( - type="tool_result", + type="messages-tuple", data={ - "name": getattr(msg, "name", None), + "type": "tool", "content": msg.content if isinstance(msg.content, str) else str(msg.content), + "name": getattr(msg, "name", None), "tool_call_id": getattr(msg, "tool_call_id", None), + "id": msg_id, }, ) - # Title changes - title = chunk.get("title") - if title and title != last_title: - last_title = title - yield StreamEvent(type="title", data={"title": title}) + # Emit a values event for each state snapshot + yield StreamEvent( + type="values", + data={ + "title": chunk.get("title"), + "messages": [self._serialize_message(m) for m in messages], + "artifacts": chunk.get("artifacts", []), + }, + ) - yield StreamEvent(type="done", data={}) + yield StreamEvent(type="end", data={}) def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str: """Send a message and return the final text response. Convenience wrapper around :meth:`stream` that returns only the - **last** ``message`` event's text. If the agent emits multiple - message segments in one turn, intermediate segments are discarded. - Use :meth:`stream` directly to capture all events. + **last** AI text from ``messages-tuple`` events. If the agent emits + multiple text segments in one turn, intermediate segments are + discarded. Use :meth:`stream` directly to capture all events. Args: message: User message text. @@ -330,42 +376,59 @@ class DeerFlowClient: """ last_text = "" for event in self.stream(message, thread_id=thread_id, **kwargs): - if event.type == "message": - last_text = event.data.get("content", "") + if event.type == "messages-tuple" and event.data.get("type") == "ai": + content = event.data.get("content", "") + if content: + last_text = content return last_text # ------------------------------------------------------------------ # Public API — configuration queries # ------------------------------------------------------------------ - def list_models(self) -> list[dict]: + def list_models(self) -> dict: """List available models from configuration. Returns: - List of model config dicts. + Dict with "models" key containing list of model info dicts, + matching the Gateway API ``ModelsListResponse`` schema. """ - return [model.model_dump() for model in self._app_config.models] + return { + "models": [ + { + "name": model.name, + "display_name": getattr(model, "display_name", None), + "description": getattr(model, "description", None), + "supports_thinking": getattr(model, "supports_thinking", False), + } + for model in self._app_config.models + ] + } - def list_skills(self, enabled_only: bool = False) -> list[dict]: + def list_skills(self, enabled_only: bool = False) -> dict: """List available skills. Args: enabled_only: If True, only return enabled skills. Returns: - List of skill info dicts with name, description, category, enabled. + Dict with "skills" key containing list of skill info dicts, + matching the Gateway API ``SkillsListResponse`` schema. """ from src.skills.loader import load_skills - return [ - { - "name": s.name, - "description": s.description, - "category": s.category, - "enabled": s.enabled, - } - for s in load_skills(enabled_only=enabled_only) - ] + return { + "skills": [ + { + "name": s.name, + "description": s.description, + "license": s.license, + "category": s.category, + "enabled": s.enabled, + } + for s in load_skills(enabled_only=enabled_only) + ] + } def get_memory(self) -> dict: """Get current memory data. @@ -384,25 +447,34 @@ class DeerFlowClient: name: Model name. Returns: - Model config dict, or None if not found. + Model info dict matching the Gateway API ``ModelResponse`` + schema, or None if not found. """ model = self._app_config.get_model_config(name) - return model.model_dump() if model is not None else None + if model is None: + return None + return { + "name": model.name, + "display_name": getattr(model, "display_name", None), + "description": getattr(model, "description", None), + "supports_thinking": getattr(model, "supports_thinking", False), + } # ------------------------------------------------------------------ # Public API — MCP configuration # ------------------------------------------------------------------ - def get_mcp_config(self) -> dict[str, dict]: + def get_mcp_config(self) -> dict: """Get MCP server configurations. Returns: - Dict mapping server name to its config dict. + Dict with "mcp_servers" key mapping server name to config, + matching the Gateway API ``McpConfigResponse`` schema. """ config = get_extensions_config() - return {name: server.model_dump() for name, server in config.mcp_servers.items()} + return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}} - def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict[str, dict]: + def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict: """Update MCP server configurations. Writes to extensions_config.json and reloads the cache. @@ -412,7 +484,8 @@ class DeerFlowClient: Each value should contain keys like enabled, type, command, args, env, url, etc. Returns: - The updated MCP config. + Dict with "mcp_servers" key, matching the Gateway API + ``McpConfigResponse`` schema. Raises: OSError: If the config file cannot be written. @@ -421,7 +494,7 @@ class DeerFlowClient: if config_path is None: raise FileNotFoundError( "Cannot locate extensions_config.json. " - "Pass config_path to DeerFlowClient or set DEER_FLOW_HOME." + "Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root." ) current_config = get_extensions_config() @@ -435,7 +508,7 @@ class DeerFlowClient: self._agent = None reloaded = reload_extensions_config() - return {name: server.model_dump() for name, server in reloaded.mcp_servers.items()} + return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}} # ------------------------------------------------------------------ # Public API — skills management @@ -488,7 +561,7 @@ class DeerFlowClient: if config_path is None: raise FileNotFoundError( "Cannot locate extensions_config.json. " - "Pass config_path to DeerFlowClient or set DEER_FLOW_HOME." + "Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root." ) extensions_config = get_extensions_config() @@ -634,7 +707,7 @@ class DeerFlowClient: base.mkdir(parents=True, exist_ok=True) return base - def upload_files(self, thread_id: str, files: list[str | Path]) -> list[dict]: + def upload_files(self, thread_id: str, files: list[str | Path]) -> dict: """Upload local files into a thread's uploads directory. For PDF, PPT, Excel, and Word files, they are also converted to Markdown. @@ -644,7 +717,8 @@ class DeerFlowClient: files: List of local file paths to upload. Returns: - List of file info dicts (filename, size, path, virtual_path). + Dict with success, files, message — matching the Gateway API + ``UploadResponse`` schema. Raises: FileNotFoundError: If any file does not exist. @@ -660,7 +734,7 @@ class DeerFlowClient: resolved_files.append(p) uploads_dir = self._get_uploads_dir(thread_id) - results: list[dict] = [] + uploaded_files: list[dict] = [] for src_path in resolved_files: @@ -669,9 +743,10 @@ class DeerFlowClient: info: dict[str, Any] = { "filename": src_path.name, - "size": dest.stat().st_size, + "size": str(dest.stat().st_size), "path": str(dest), "virtual_path": f"/mnt/user-data/uploads/{src_path.name}", + "artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{src_path.name}", } if src_path.suffix.lower() in CONVERTIBLE_EXTENSIONS: @@ -690,23 +765,29 @@ class DeerFlowClient: if md_path is not None: info["markdown_file"] = md_path.name info["markdown_virtual_path"] = f"/mnt/user-data/uploads/{md_path.name}" + info["markdown_artifact_url"] = f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{md_path.name}" - results.append(info) + uploaded_files.append(info) - return results + return { + "success": True, + "files": uploaded_files, + "message": f"Successfully uploaded {len(uploaded_files)} file(s)", + } - def list_uploads(self, thread_id: str) -> list[dict]: + def list_uploads(self, thread_id: str) -> dict: """List files in a thread's uploads directory. Args: thread_id: Thread ID. Returns: - List of file info dicts. + Dict with "files" and "count" keys, matching the Gateway API + ``list_uploaded_files`` response. """ uploads_dir = self._get_uploads_dir(thread_id) if not uploads_dir.exists(): - return [] + return {"files": [], "count": 0} files = [] for fp in sorted(uploads_dir.iterdir()): @@ -714,21 +795,26 @@ class DeerFlowClient: stat = fp.stat() files.append({ "filename": fp.name, - "size": stat.st_size, + "size": str(stat.st_size), "path": str(fp), "virtual_path": f"/mnt/user-data/uploads/{fp.name}", + "artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{fp.name}", "extension": fp.suffix, "modified": stat.st_mtime, }) - return files + return {"files": files, "count": len(files)} - def delete_upload(self, thread_id: str, filename: str) -> None: + def delete_upload(self, thread_id: str, filename: str) -> dict: """Delete a file from a thread's uploads directory. Args: thread_id: Thread ID. filename: Filename to delete. + Returns: + Dict with success and message, matching the Gateway API + ``delete_uploaded_file`` response. + Raises: FileNotFoundError: If the file does not exist. PermissionError: If path traversal is detected. @@ -738,13 +824,14 @@ class DeerFlowClient: try: file_path.relative_to(uploads_dir.resolve()) - except ValueError: - raise PermissionError("Access denied: path traversal detected") + except ValueError as exc: + raise PermissionError("Access denied: path traversal detected") from exc if not file_path.is_file(): raise FileNotFoundError(f"File not found: {filename}") file_path.unlink() + return {"success": True, "message": f"Deleted {filename}"} # ------------------------------------------------------------------ # Public API — artifacts @@ -775,8 +862,8 @@ class DeerFlowClient: try: actual.relative_to(base_dir.resolve()) - except ValueError: - raise PermissionError("Access denied: path traversal detected") + except ValueError as exc: + raise PermissionError("Access denied: path traversal detected") from exc if not actual.exists(): raise FileNotFoundError(f"Artifact not found: {path}") if not actual.is_file(): diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py index 729a1fe..06ff9db 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/test_client.py @@ -10,6 +10,11 @@ import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage # noqa: F401 from src.client import DeerFlowClient +from src.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse +from src.gateway.routers.mcp import McpConfigResponse +from src.gateway.routers.models import ModelResponse, ModelsListResponse +from src.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse +from src.gateway.routers.uploads import UploadResponse # --------------------------------------------------------------------------- # Fixtures @@ -81,14 +86,19 @@ class TestClientInit: class TestConfigQueries: def test_list_models(self, client): - models = client.list_models() - assert len(models) == 1 - assert models[0]["name"] == "test-model" + result = client.list_models() + assert "models" in result + assert len(result["models"]) == 1 + assert result["models"][0]["name"] == "test-model" + # Verify Gateway-aligned fields are present + assert "display_name" in result["models"][0] + assert "supports_thinking" in result["models"][0] def test_list_skills(self, client): skill = MagicMock() skill.name = "web-search" skill.description = "Search the web" + skill.license = "MIT" skill.category = "public" skill.enabled = True @@ -96,10 +106,12 @@ class TestConfigQueries: result = client.list_skills() mock_load.assert_called_once_with(enabled_only=False) - assert len(result) == 1 - assert result[0] == { + assert "skills" in result + assert len(result["skills"]) == 1 + assert result["skills"][0] == { "name": "web-search", "description": "Search the web", + "license": "MIT", "category": "public", "enabled": True, } @@ -128,9 +140,24 @@ def _make_agent_mock(chunks: list[dict]): return agent +def _ai_events(events): + """Filter messages-tuple events with type=ai and non-empty content.""" + return [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")] + + +def _tool_call_events(events): + """Filter messages-tuple events with type=ai and tool_calls.""" + return [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data] + + +def _tool_result_events(events): + """Filter messages-tuple events with type=tool.""" + return [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "tool"] + + class TestStream: def test_basic_message(self, client): - """stream() emits message + done for a simple AI reply.""" + """stream() emits messages-tuple + values + end for a simple AI reply.""" ai = AIMessage(content="Hello!", id="ai-1") chunks = [ {"messages": [HumanMessage(content="hi", id="h-1")]}, @@ -145,13 +172,14 @@ class TestStream: events = list(client.stream("hi", thread_id="t1")) types = [e.type for e in events] - assert "message" in types - assert types[-1] == "done" - msg_events = [e for e in events if e.type == "message"] + assert "messages-tuple" in types + assert "values" in types + assert types[-1] == "end" + msg_events = _ai_events(events) assert msg_events[0].data["content"] == "Hello!" def test_tool_call_and_result(self, client): - """stream() emits tool_call and tool_result events.""" + """stream() emits messages-tuple events for tool calls and results.""" ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}]) tool = ToolMessage(content="file.txt", id="tm-1", tool_call_id="tc-1", name="bash") ai2 = AIMessage(content="Here are the files.", id="ai-2") @@ -169,14 +197,13 @@ class TestStream: ): events = list(client.stream("list files", thread_id="t2")) - types = [e.type for e in events] - assert "tool_call" in types - assert "tool_result" in types - assert "message" in types - assert types[-1] == "done" + assert len(_tool_call_events(events)) >= 1 + assert len(_tool_result_events(events)) >= 1 + assert len(_ai_events(events)) >= 1 + assert events[-1].type == "end" - def test_title_event(self, client): - """stream() emits title event when title appears in state.""" + def test_values_event_with_title(self, client): + """stream() emits values event containing title when present in state.""" ai = AIMessage(content="ok", id="ai-1") chunks = [ {"messages": [HumanMessage(content="hi", id="h-1"), ai], "title": "Greeting"}, @@ -189,9 +216,10 @@ class TestStream: ): events = list(client.stream("hi", thread_id="t3")) - title_events = [e for e in events if e.type == "title"] - assert len(title_events) == 1 - assert title_events[0].data["title"] == "Greeting" + values_events = [e for e in events if e.type == "values"] + assert len(values_events) >= 1 + assert values_events[-1].data["title"] == "Greeting" + assert "messages" in values_events[-1].data def test_deduplication(self, client): """Messages with the same id are not emitted twice.""" @@ -208,7 +236,7 @@ class TestStream: ): events = list(client.stream("hi", thread_id="t4")) - msg_events = [e for e in events if e.type == "message"] + msg_events = _ai_events(events) assert len(msg_events) == 1 def test_auto_thread_id(self, client): @@ -221,8 +249,8 @@ class TestStream: ): events = list(client.stream("hi")) - # Should not raise; done event proves it completed - assert events[-1].type == "done" + # Should not raise; end event proves it completed + assert events[-1].type == "end" def test_list_content_blocks(self, client): """stream() handles AIMessage with list-of-blocks content.""" @@ -242,7 +270,7 @@ class TestStream: ): events = list(client.stream("hi", thread_id="t5")) - msg_events = [e for e in events if e.type == "message"] + msg_events = _ai_events(events) assert len(msg_events) == 1 assert msg_events[0].data["content"] == "result" @@ -347,11 +375,19 @@ class TestEnsureAgent: class TestGetModel: def test_found(self, client): model_cfg = MagicMock() - model_cfg.model_dump.return_value = {"name": "test-model"} + model_cfg.name = "test-model" + model_cfg.display_name = "Test Model" + model_cfg.description = "A test model" + model_cfg.supports_thinking = True client._app_config.get_model_config.return_value = model_cfg result = client.get_model("test-model") - assert result == {"name": "test-model"} + assert result == { + "name": "test-model", + "display_name": "Test Model", + "description": "A test model", + "supports_thinking": True, + } def test_not_found(self, client): client._app_config.get_model_config.return_value = None @@ -372,8 +408,9 @@ class TestMcpConfig: with patch("src.client.get_extensions_config", return_value=ext_config): result = client.get_mcp_config() - assert "github" in result - assert result["github"]["enabled"] is True + assert "mcp_servers" in result + assert "github" in result["mcp_servers"] + assert result["mcp_servers"]["github"]["enabled"] is True def test_update_mcp_config(self, client): # Set up current config with skills @@ -400,7 +437,8 @@ class TestMcpConfig: ): result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}}) - assert "new-server" in result + assert "mcp_servers" in result + assert "new-server" in result["mcp_servers"] assert client._agent is None # M2: agent invalidated # Verify file was actually written @@ -578,8 +616,11 @@ class TestUploads: with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): result = client.upload_files("thread-1", [src_file]) - assert len(result) == 1 - assert result[0]["filename"] == "test.txt" + assert result["success"] is True + assert len(result["files"]) == 1 + assert result["files"][0]["filename"] == "test.txt" + assert "artifact_url" in result["files"][0] + assert "message" in result assert (uploads_dir / "test.txt").exists() def test_upload_files_not_found(self, client): @@ -595,9 +636,13 @@ class TestUploads: with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): result = client.list_uploads("thread-1") - assert len(result) == 2 - names = {f["filename"] for f in result} + assert result["count"] == 2 + assert len(result["files"]) == 2 + names = {f["filename"] for f in result["files"]} assert names == {"a.txt", "b.txt"} + # Verify artifact_url is present + for f in result["files"]: + assert "artifact_url" in f def test_delete_upload(self, client): with tempfile.TemporaryDirectory() as tmp: @@ -605,8 +650,10 @@ class TestUploads: (uploads_dir / "delete-me.txt").write_text("gone") with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): - client.delete_upload("thread-1", "delete-me.txt") + result = client.delete_upload("thread-1", "delete-me.txt") + assert result["success"] is True + assert "delete-me.txt" in result["message"] assert not (uploads_dir / "delete-me.txt").exists() def test_delete_upload_not_found(self, client): @@ -707,7 +754,7 @@ class TestScenarioMultiTurnConversation: assert agent.stream.call_count == 2 def test_stream_collects_all_event_types_across_turns(self, client): - """A full turn with tool_call → tool_result → message → title → done.""" + """A full turn emits messages-tuple (tool_call, tool_result, ai text) + values + end.""" ai_tc = AIMessage(content="", id="ai-1", tool_calls=[ {"name": "web_search", "args": {"query": "LangGraph"}, "id": "tc-1"}, ]) @@ -727,23 +774,30 @@ class TestScenarioMultiTurnConversation: ): events = list(client.stream("search", thread_id="t-full")) - types = [e.type for e in events] - assert types == ["tool_call", "tool_result", "message", "title", "done"] + # Verify expected event types + types = set(e.type for e in events) + assert types == {"messages-tuple", "values", "end"} + assert events[-1].type == "end" - # Verify event data integrity - tc_event = events[0] - assert tc_event.data["name"] == "web_search" - assert tc_event.data["args"] == {"query": "LangGraph"} + # Verify tool_call data + tc_events = _tool_call_events(events) + assert len(tc_events) == 1 + assert tc_events[0].data["tool_calls"][0]["name"] == "web_search" + assert tc_events[0].data["tool_calls"][0]["args"] == {"query": "LangGraph"} - tr_event = events[1] - assert tr_event.data["tool_call_id"] == "tc-1" - assert "LangGraph" in tr_event.data["content"] + # Verify tool_result data + tr_events = _tool_result_events(events) + assert len(tr_events) == 1 + assert tr_events[0].data["tool_call_id"] == "tc-1" + assert "LangGraph" in tr_events[0].data["content"] - msg_event = events[2] - assert "framework" in msg_event.data["content"] + # Verify AI text + msg_events = _ai_events(events) + assert any("framework" in e.data["content"] for e in msg_events) - title_event = events[3] - assert title_event.data["title"] == "LangGraph Search" + # Verify values event contains title + values_events = [e for e in events if e.type == "values"] + assert any(e.data.get("title") == "LangGraph Search" for e in values_events) class TestScenarioToolChain: @@ -776,16 +830,16 @@ class TestScenarioToolChain: ): events = list(client.stream("list and save", thread_id="t-chain")) - tool_calls = [e for e in events if e.type == "tool_call"] - tool_results = [e for e in events if e.type == "tool_result"] - messages = [e for e in events if e.type == "message"] + tool_calls = _tool_call_events(events) + tool_results = _tool_result_events(events) + messages = _ai_events(events) assert len(tool_calls) == 2 - assert tool_calls[0].data["name"] == "bash" - assert tool_calls[1].data["name"] == "write_file" + assert tool_calls[0].data["tool_calls"][0]["name"] == "bash" + assert tool_calls[1].data["tool_calls"][0]["name"] == "write_file" assert len(tool_results) == 2 assert len(messages) == 1 - assert events[-1].type == "done" + assert events[-1].type == "end" class TestScenarioFileLifecycle: @@ -804,25 +858,27 @@ class TestScenarioFileLifecycle: with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): # Step 1: Upload - uploaded = client.upload_files("t-lifecycle", [ + result = client.upload_files("t-lifecycle", [ tmp_path / "report.txt", tmp_path / "data.csv", ]) - assert len(uploaded) == 2 - assert {f["filename"] for f in uploaded} == {"report.txt", "data.csv"} + assert result["success"] is True + assert len(result["files"]) == 2 + assert {f["filename"] for f in result["files"]} == {"report.txt", "data.csv"} # Step 2: List - files = client.list_uploads("t-lifecycle") - assert len(files) == 2 - assert all("virtual_path" in f for f in files) + listed = client.list_uploads("t-lifecycle") + assert listed["count"] == 2 + assert all("virtual_path" in f for f in listed["files"]) # Step 3: Delete one - client.delete_upload("t-lifecycle", "report.txt") + del_result = client.delete_upload("t-lifecycle", "report.txt") + assert del_result["success"] is True # Step 4: Verify deletion - files = client.list_uploads("t-lifecycle") - assert len(files) == 1 - assert files[0]["filename"] == "data.csv" + listed = client.list_uploads("t-lifecycle") + assert listed["count"] == 1 + assert listed["files"][0]["filename"] == "data.csv" def test_upload_then_read_artifact(self, client): """Upload a file, simulate agent producing artifact, read it back.""" @@ -840,7 +896,7 @@ class TestScenarioFileLifecycle: with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): uploaded = client.upload_files("t-artifact", [src_file]) - assert len(uploaded) == 1 + assert len(uploaded["files"]) == 1 # Simulate agent writing an artifact (outputs_dir / "analysis.json").write_text('{"result": "processed"}') @@ -862,13 +918,16 @@ class TestScenarioConfigManagement: def test_model_and_skill_discovery(self, client): """List models → get specific model → list skills → get specific skill.""" # List models - models = client.list_models() - assert len(models) >= 1 - model_name = models[0]["name"] + result = client.list_models() + assert len(result["models"]) >= 1 + model_name = result["models"][0]["name"] # Get specific model model_cfg = MagicMock() - model_cfg.model_dump.return_value = {"name": model_name, "use": "langchain_openai:ChatOpenAI"} + model_cfg.name = model_name + model_cfg.display_name = None + model_cfg.description = None + model_cfg.supports_thinking = False client._app_config.get_model_config.return_value = model_cfg detail = client.get_model(model_name) assert detail["name"] == model_name @@ -877,12 +936,13 @@ class TestScenarioConfigManagement: skill = MagicMock() skill.name = "web-search" skill.description = "Search the web" + skill.license = "MIT" skill.category = "public" skill.enabled = True with patch("src.skills.loader.load_skills", return_value=[skill]): - skills = client.list_skills() - assert len(skills) == 1 + skills_result = client.list_skills() + assert len(skills_result["skills"]) == 1 # Get specific skill with patch("src.skills.loader.load_skills", return_value=[skill]): @@ -912,7 +972,7 @@ class TestScenarioConfigManagement: patch("src.client.reload_extensions_config", return_value=reloaded_config), ): mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}}) - assert "my-mcp" in mcp_result + assert "my-mcp" in mcp_result["mcp_servers"] assert client._agent is None # Agent invalidated # --- Skill toggle --- @@ -1072,8 +1132,8 @@ class TestScenarioThreadIsolation: files_a = client.list_uploads("thread-a") files_b = client.list_uploads("thread-b") - assert len(files_a) == 1 - assert len(files_b) == 0 + assert files_a["count"] == 1 + assert files_b["count"] == 0 def test_artifacts_isolated_per_thread(self, client): """Artifacts in thread-A are not accessible from thread-B.""" @@ -1168,12 +1228,13 @@ class TestScenarioSkillInstallAndUse: installed_skill = MagicMock() installed_skill.name = "my-analyzer" installed_skill.description = "Analyze code" + installed_skill.license = "MIT" installed_skill.category = "custom" installed_skill.enabled = True with patch("src.skills.loader.load_skills", return_value=[installed_skill]): - skills = client.list_skills() - assert any(s["name"] == "my-analyzer" for s in skills) + skills_result = client.list_skills() + assert any(s["name"] == "my-analyzer" for s in skills_result["skills"]) # Step 3: Disable it disabled_skill = MagicMock() @@ -1204,7 +1265,7 @@ class TestScenarioEdgeCases: """Scenario: Edge cases and error boundaries in realistic workflows.""" def test_empty_stream_response(self, client): - """Agent produces no messages — only done event.""" + """Agent produces no messages — only values + end events.""" agent = _make_agent_mock([{"messages": []}]) with ( @@ -1213,8 +1274,10 @@ class TestScenarioEdgeCases: ): events = list(client.stream("hi", thread_id="t-empty")) - assert len(events) == 1 - assert events[0].type == "done" + # values event (empty messages) + end + assert len(events) == 2 + assert events[0].type == "values" + assert events[-1].type == "end" def test_chat_on_empty_response(self, client): """chat() returns empty string for no-message response.""" @@ -1229,12 +1292,12 @@ class TestScenarioEdgeCases: assert result == "" def test_multiple_title_changes(self, client): - """Only distinct title changes produce events.""" + """Title changes are carried in values events.""" ai = AIMessage(content="ok", id="ai-1") chunks = [ {"messages": [ai], "title": "First Title"}, - {"messages": [], "title": "First Title"}, # same — should NOT emit - {"messages": [], "title": "Second Title"}, # different — should emit + {"messages": [], "title": "First Title"}, # same title repeated + {"messages": [], "title": "Second Title"}, # different title ] agent = _make_agent_mock(chunks) @@ -1244,13 +1307,15 @@ class TestScenarioEdgeCases: ): events = list(client.stream("hi", thread_id="t-titles")) - title_events = [e for e in events if e.type == "title"] - assert len(title_events) == 2 - assert title_events[0].data["title"] == "First Title" - assert title_events[1].data["title"] == "Second Title" + # Every chunk produces a values event with the title + values_events = [e for e in events if e.type == "values"] + assert len(values_events) == 3 + assert values_events[0].data["title"] == "First Title" + assert values_events[1].data["title"] == "First Title" + assert values_events[2].data["title"] == "Second Title" def test_concurrent_tool_calls_in_single_message(self, client): - """Agent produces multiple tool_calls in one AIMessage.""" + """Agent produces multiple tool_calls in one AIMessage — emitted as single messages-tuple.""" ai = AIMessage(content="", id="ai-1", tool_calls=[ {"name": "web_search", "args": {"q": "a"}, "id": "tc-1"}, {"name": "web_search", "args": {"q": "b"}, "id": "tc-2"}, @@ -1265,9 +1330,11 @@ class TestScenarioEdgeCases: ): events = list(client.stream("do things", thread_id="t-parallel")) - tc_events = [e for e in events if e.type == "tool_call"] - assert len(tc_events) == 3 - assert {e.data["id"] for e in tc_events} == {"tc-1", "tc-2", "tc-3"} + tc_events = _tool_call_events(events) + assert len(tc_events) == 1 # One messages-tuple event for the AIMessage + tool_calls = tc_events[0].data["tool_calls"] + assert len(tool_calls) == 3 + assert {tc["id"] for tc in tool_calls} == {"tc-1", "tc-2", "tc-3"} def test_upload_convertible_file_conversion_failure(self, client): """Upload a .pdf file where conversion fails — file still uploaded, no markdown.""" @@ -1284,9 +1351,223 @@ class TestScenarioEdgeCases: patch("src.gateway.routers.uploads.CONVERTIBLE_EXTENSIONS", {".pdf"}), patch("src.gateway.routers.uploads.convert_file_to_markdown", side_effect=Exception("conversion failed")), ): - results = client.upload_files("t-pdf-fail", [pdf_file]) + result = client.upload_files("t-pdf-fail", [pdf_file]) - assert len(results) == 1 - assert results[0]["filename"] == "doc.pdf" - assert "markdown_file" not in results[0] # Conversion failed gracefully + assert result["success"] is True + assert len(result["files"]) == 1 + assert result["files"][0]["filename"] == "doc.pdf" + assert "markdown_file" not in result["files"][0] # Conversion failed gracefully assert (uploads_dir / "doc.pdf").exists() # File still uploaded + + +# --------------------------------------------------------------------------- +# Gateway conformance — validate client output against Gateway Pydantic models +# --------------------------------------------------------------------------- + +class TestGatewayConformance: + """Validate that DeerFlowClient return dicts conform to Gateway Pydantic response models. + + Each test calls a client method, then parses the result through the + corresponding Gateway response model. If the client drifts (missing or + wrong-typed fields), Pydantic raises ``ValidationError`` and CI catches it. + """ + + def test_list_models(self, mock_app_config): + model = MagicMock() + model.name = "test-model" + model.display_name = "Test Model" + model.description = "A test model" + model.supports_thinking = False + mock_app_config.models = [model] + + with patch("src.client.get_app_config", return_value=mock_app_config): + client = DeerFlowClient() + + result = client.list_models() + parsed = ModelsListResponse(**result) + assert len(parsed.models) == 1 + assert parsed.models[0].name == "test-model" + + def test_get_model(self, mock_app_config): + model = MagicMock() + model.name = "test-model" + model.display_name = "Test Model" + model.description = "A test model" + model.supports_thinking = True + mock_app_config.models = [model] + mock_app_config.get_model_config.return_value = model + + with patch("src.client.get_app_config", return_value=mock_app_config): + client = DeerFlowClient() + + result = client.get_model("test-model") + assert result is not None + parsed = ModelResponse(**result) + assert parsed.name == "test-model" + + def test_list_skills(self, client): + skill = MagicMock() + skill.name = "web-search" + skill.description = "Search the web" + skill.license = "MIT" + skill.category = "public" + skill.enabled = True + + with patch("src.skills.loader.load_skills", return_value=[skill]): + result = client.list_skills() + + parsed = SkillsListResponse(**result) + assert len(parsed.skills) == 1 + assert parsed.skills[0].name == "web-search" + + def test_get_skill(self, client): + skill = MagicMock() + skill.name = "web-search" + skill.description = "Search the web" + skill.license = "MIT" + skill.category = "public" + skill.enabled = True + + with patch("src.skills.loader.load_skills", return_value=[skill]): + result = client.get_skill("web-search") + + assert result is not None + parsed = SkillResponse(**result) + assert parsed.name == "web-search" + + def test_install_skill(self, client, tmp_path): + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: my-skill\ndescription: A test skill\n---\nBody\n" + ) + + archive = tmp_path / "my-skill.skill" + with zipfile.ZipFile(archive, "w") as zf: + zf.write(skill_dir / "SKILL.md", "my-skill/SKILL.md") + + custom_dir = tmp_path / "custom" + custom_dir.mkdir() + with patch("src.skills.loader.get_skills_root_path", return_value=tmp_path): + result = client.install_skill(archive) + + parsed = SkillInstallResponse(**result) + assert parsed.success is True + assert parsed.skill_name == "my-skill" + + def test_get_mcp_config(self, client): + server = MagicMock() + server.model_dump.return_value = { + "enabled": True, + "type": "stdio", + "command": "npx", + "args": ["-y", "server"], + "env": {}, + "url": None, + "headers": {}, + "description": "test server", + } + ext_config = MagicMock() + ext_config.mcp_servers = {"test": server} + + with patch("src.client.get_extensions_config", return_value=ext_config): + result = client.get_mcp_config() + + parsed = McpConfigResponse(**result) + assert "test" in parsed.mcp_servers + + def test_update_mcp_config(self, client, tmp_path): + server = MagicMock() + server.model_dump.return_value = { + "enabled": True, + "type": "stdio", + "command": "npx", + "args": [], + "env": {}, + "url": None, + "headers": {}, + "description": "", + } + ext_config = MagicMock() + ext_config.mcp_servers = {"srv": server} + ext_config.skills = {} + + config_file = tmp_path / "extensions_config.json" + config_file.write_text("{}") + + with ( + patch("src.client.get_extensions_config", return_value=ext_config), + patch("src.client.ExtensionsConfig.resolve_config_path", return_value=config_file), + patch("src.client.reload_extensions_config", return_value=ext_config), + ): + result = client.update_mcp_config({"srv": server.model_dump.return_value}) + + parsed = McpConfigResponse(**result) + assert "srv" in parsed.mcp_servers + + def test_upload_files(self, client, tmp_path): + uploads_dir = tmp_path / "uploads" + uploads_dir.mkdir() + + src_file = tmp_path / "hello.txt" + src_file.write_text("hello") + + with patch.object(DeerFlowClient, "_get_uploads_dir", return_value=uploads_dir): + result = client.upload_files("t-conform", [src_file]) + + parsed = UploadResponse(**result) + assert parsed.success is True + assert len(parsed.files) == 1 + + def test_get_memory_config(self, client): + mem_cfg = MagicMock() + mem_cfg.enabled = True + mem_cfg.storage_path = ".deer-flow/memory.json" + mem_cfg.debounce_seconds = 30 + mem_cfg.max_facts = 100 + mem_cfg.fact_confidence_threshold = 0.7 + mem_cfg.injection_enabled = True + mem_cfg.max_injection_tokens = 2000 + + with patch("src.config.memory_config.get_memory_config", return_value=mem_cfg): + result = client.get_memory_config() + + parsed = MemoryConfigResponse(**result) + assert parsed.enabled is True + assert parsed.max_facts == 100 + + def test_get_memory_status(self, client): + mem_cfg = MagicMock() + mem_cfg.enabled = True + mem_cfg.storage_path = ".deer-flow/memory.json" + mem_cfg.debounce_seconds = 30 + mem_cfg.max_facts = 100 + mem_cfg.fact_confidence_threshold = 0.7 + mem_cfg.injection_enabled = True + mem_cfg.max_injection_tokens = 2000 + + memory_data = { + "version": "1.0", + "lastUpdated": "", + "user": { + "workContext": {"summary": "", "updatedAt": ""}, + "personalContext": {"summary": "", "updatedAt": ""}, + "topOfMind": {"summary": "", "updatedAt": ""}, + }, + "history": { + "recentMonths": {"summary": "", "updatedAt": ""}, + "earlierContext": {"summary": "", "updatedAt": ""}, + "longTermBackground": {"summary": "", "updatedAt": ""}, + }, + "facts": [], + } + + with ( + patch("src.config.memory_config.get_memory_config", return_value=mem_cfg), + patch("src.agents.memory.updater.get_memory_data", return_value=memory_data), + ): + result = client.get_memory_status() + + parsed = MemoryStatusResponse(**result) + assert parsed.config.enabled is True + assert parsed.data.version == "1.0" diff --git a/backend/tests/test_client_live.py b/backend/tests/test_client_live.py index b1d8d05..3785df6 100644 --- a/backend/tests/test_client_live.py +++ b/backend/tests/test_client_live.py @@ -8,7 +8,6 @@ They are skipped in CI and must be run explicitly: import json import os -import tempfile from pathlib import Path import pytest @@ -68,26 +67,27 @@ class TestLiveBasicChat: # =========================================================================== class TestLiveStreaming: - def test_stream_yields_message_and_done(self, client): - """stream() produces at least one message event and ends with done.""" + def test_stream_yields_messages_tuple_and_end(self, client): + """stream() produces at least one messages-tuple event and ends with end.""" events = list(client.stream("Say hi in one word.")) types = [e.type for e in events] - assert "message" in types, f"Expected 'message' event, got: {types}" - assert types[-1] == "done" + assert "messages-tuple" in types, f"Expected 'messages-tuple' event, got: {types}" + assert "values" in types, f"Expected 'values' event, got: {types}" + assert types[-1] == "end" for e in events: assert isinstance(e, StreamEvent) print(f" [{e.type}] {e.data}") - def test_stream_message_content_nonempty(self, client): - """Streamed message events contain non-empty content.""" - messages = [ + def test_stream_ai_content_nonempty(self, client): + """Streamed messages-tuple AI events contain non-empty content.""" + ai_messages = [ e for e in client.stream("What color is the sky? One word.") - if e.type == "message" + if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content") ] - assert len(messages) >= 1 - for m in messages: + assert len(ai_messages) >= 1 + for m in ai_messages: assert len(m.data.get("content", "")) > 0 @@ -108,16 +108,18 @@ class TestLiveToolUse: for e in events: print(f" [{e.type}] {e.data}") - # Should have tool_call + tool_result + message - assert "tool_call" in types, f"Expected tool_call, got: {types}" - assert "tool_result" in types, f"Expected tool_result, got: {types}" - assert "message" in types + # All message events are now messages-tuple + mt_events = [e for e in events if e.type == "messages-tuple"] + tc_events = [e for e in mt_events if e.data.get("type") == "ai" and "tool_calls" in e.data] + tr_events = [e for e in mt_events if e.data.get("type") == "tool"] + ai_events = [e for e in mt_events if e.data.get("type") == "ai" and e.data.get("content")] - tc = next(e for e in events if e.type == "tool_call") - assert tc.data["name"] == "bash" + assert len(tc_events) >= 1, f"Expected tool_call event, got types: {types}" + assert len(tr_events) >= 1, f"Expected tool result event, got types: {types}" + assert len(ai_events) >= 1 - tr = next(e for e in events if e.type == "tool_result") - assert "LIVE_TEST_OK" in tr.data["content"] + assert tc_events[0].data["tool_calls"][0]["name"] == "bash" + assert "LIVE_TEST_OK" in tr_events[0].data["content"] def test_agent_uses_ls_tool(self, client): """Agent uses ls tool to list a directory.""" @@ -129,9 +131,9 @@ class TestLiveToolUse: types = [e.type for e in events] print(f" event types: {types}") - assert "tool_call" in types - tc = next(e for e in events if e.type == "tool_call") - assert tc.data["name"] == "ls" + tc_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data] + assert len(tc_events) >= 1 + assert tc_events[0].data["tool_calls"][0]["name"] == "ls" # =========================================================================== @@ -153,18 +155,19 @@ class TestLiveMultiToolChain: for e in events: print(f" [{e.type}] {e.data}") - tool_calls = [e for e in events if e.type == "tool_call"] - tool_names = [tc.data["name"] for tc in tool_calls] + tc_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data] + tool_names = [tc.data["tool_calls"][0]["name"] for tc in tc_events] assert "write_file" in tool_names, f"Expected write_file, got: {tool_names}" assert "read_file" in tool_names, f"Expected read_file, got: {tool_names}" - # Final message should mention the content - messages = [e for e in events if e.type == "message"] - final_text = messages[-1].data["content"] if messages else "" + # Final AI message or tool result should mention the content + ai_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")] + tr_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "tool"] + final_text = ai_events[-1].data["content"] if ai_events else "" assert "integration_test_content" in final_text.lower() or any( "integration_test_content" in e.data.get("content", "") - for e in events if e.type == "tool_result" + for e in tr_events ) @@ -184,30 +187,35 @@ class TestLiveFileUpload: f2.write_text("content B") # Upload - results = client.upload_files(thread_id, [f1, f2]) - assert len(results) == 2 - filenames = {r["filename"] for r in results} + result = client.upload_files(thread_id, [f1, f2]) + assert result["success"] is True + assert len(result["files"]) == 2 + filenames = {r["filename"] for r in result["files"]} assert filenames == {"test_upload_a.txt", "test_upload_b.txt"} - for r in results: - assert r["size"] > 0 + for r in result["files"]: + assert int(r["size"]) > 0 assert r["virtual_path"].startswith("/mnt/user-data/uploads/") + assert "artifact_url" in r print(f" uploaded: {filenames}") # List listed = client.list_uploads(thread_id) - assert len(listed) == 2 - print(f" listed: {[f['filename'] for f in listed]}") + assert listed["count"] == 2 + print(f" listed: {[f['filename'] for f in listed['files']]}") # Delete one - client.delete_upload(thread_id, "test_upload_a.txt") + del_result = client.delete_upload(thread_id, "test_upload_a.txt") + assert del_result["success"] is True remaining = client.list_uploads(thread_id) - assert len(remaining) == 1 - assert remaining[0]["filename"] == "test_upload_b.txt" - print(f" after delete: {[f['filename'] for f in remaining]}") + assert remaining["count"] == 1 + assert remaining["files"][0]["filename"] == "test_upload_b.txt" + print(f" after delete: {[f['filename'] for f in remaining['files']]}") # Delete the other client.delete_upload(thread_id, "test_upload_b.txt") - assert client.list_uploads(thread_id) == [] + empty = client.list_uploads(thread_id) + assert empty["count"] == 0 + assert empty["files"] == [] def test_upload_nonexistent_file_raises(self, client): with pytest.raises(FileNotFoundError): @@ -221,10 +229,15 @@ class TestLiveFileUpload: class TestLiveConfigQueries: def test_list_models_returns_ark(self, client): """list_models() returns the configured ARK model.""" - models = client.list_models() - assert len(models) >= 1 - names = [m["name"] for m in models] + result = client.list_models() + assert "models" in result + assert len(result["models"]) >= 1 + names = [m["name"] for m in result["models"]] assert "ark-model" in names + # Verify Gateway-aligned fields + for m in result["models"]: + assert "display_name" in m + assert "supports_thinking" in m print(f" models: {names}") def test_get_model_found(self, client): @@ -232,6 +245,8 @@ class TestLiveConfigQueries: model = client.get_model("ark-model") assert model is not None assert model["name"] == "ark-model" + assert "display_name" in model + assert "supports_thinking" in model print(f" model detail: {model}") def test_get_model_not_found(self, client): @@ -239,10 +254,11 @@ class TestLiveConfigQueries: def test_list_skills(self, client): """list_skills() runs without error.""" - skills = client.list_skills() - assert isinstance(skills, list) - print(f" skills count: {len(skills)}") - for s in skills[:3]: + result = client.list_skills() + assert "skills" in result + assert isinstance(result["skills"], list) + print(f" skills count: {len(result['skills'])}") + for s in result["skills"][:3]: print(f" - {s['name']}: {s['enabled']}") @@ -264,8 +280,11 @@ class TestLiveArtifact: )) # Verify write happened - tool_calls = [e for e in events if e.type == "tool_call"] - assert any(tc.data["name"] == "write_file" for tc in tool_calls) + tc_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data] + assert any( + any(tc["name"] == "write_file" for tc in e.data["tool_calls"]) + for e in tc_events + ) # Read artifact content, mime = client.get_artifact(thread_id, "mnt/user-data/outputs/artifact_test.json")