Mirror of https://gitee.com/wanwujie/deer-flow, synced 2026-05-03 18:50:43 +08:00
refactor: split backend into harness (deerflow.*) and app (app.*) (#1131)
* refactor: extract shared utils to break harness→app cross-layer imports

  Move _validate_skill_frontmatter to src/skills/validation.py and CONVERTIBLE_EXTENSIONS + convert_file_to_markdown to src/utils/file_conversion.py. This eliminates the two reverse dependencies from client.py (harness layer) into gateway/routers/ (app layer), preparing for the harness/app package split.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: split backend/src into harness (deerflow.*) and app (app.*)

  Physically split the monolithic backend/src/ package into two layers:

  - **Harness** (`packages/harness/deerflow/`): publishable agent framework package with import prefix `deerflow.*`. Contains agents, sandbox, tools, models, MCP, skills, config, and all core infrastructure.
  - **App** (`app/`): unpublished application code with import prefix `app.*`. Contains gateway (FastAPI REST API) and channels (IM integrations).

  Key changes:
  - Move 13 harness modules to packages/harness/deerflow/ via git mv
  - Move gateway + channels to app/ via git mv
  - Rename all imports: src.* → deerflow.* (harness) / app.* (app layer)
  - Set up a uv workspace with deerflow-harness as a workspace member
  - Update langgraph.json, config.example.yaml, all scripts, and Docker files
  - Add a build system (hatchling) to the harness pyproject.toml
  - Add PYTHONPATH=. to gateway startup commands for app.* resolution
  - Update ruff.toml with known-first-party for import sorting
  - Update all documentation to reflect the new directory structure

  Boundary rule enforced: harness code never imports from app. All 429 tests pass. Lint clean.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: add harness→app boundary check test and update docs

  Add test_harness_boundary.py, which scans all Python files in packages/harness/deerflow/ and fails if any `from app.*` or `import app.*` statement is found. This enforces the architectural rule that the harness layer never depends on the app layer.

  Update CLAUDE.md to document the harness/app split architecture, the import conventions, and the boundary enforcement test.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: add config versioning with auto-upgrade on startup

  When the config.example.yaml schema changes, developers' local config.yaml files can silently become outdated. This adds a config_version field and an auto-upgrade mechanism so that breaking changes (like the src.* → deerflow.* renames) are applied automatically before services start.

  - Add config_version: 1 to config.example.yaml
  - Add a startup version check warning in AppConfig.from_file()
  - Add scripts/config-upgrade.sh with a migration registry for value replacements
  - Add a `make config-upgrade` target
  - Auto-run config-upgrade in serve.sh and start-daemon.sh before starting services
  - Add config error hints to service failure messages

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix comments

* fix: update src.* import in test_sandbox_tools_security to deerflow.*

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: handle empty config and search parent dirs for config.example.yaml

  Address Copilot review comments on PR #1131:
  - Guard against yaml.safe_load() returning None for empty config files
  - Search parent directories for config.example.yaml instead of only looking next to config.yaml, fixing detection in common setups

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: correct skills root path depth and config_version type coercion

  - loader.py: fix get_skills_root_path() to use 5 parent levels (was 3); after the harness split the file lives at packages/harness/deerflow/skills/, so parent×3 resolved to backend/packages/harness/ instead of backend/
  - app_config.py: coerce config_version to int() before comparison in _check_config_version() to prevent a TypeError when YAML stores the value as a string (e.g. config_version: "1")
  - tests: add regression tests for both fixes

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: update test imports from src.* to deerflow.*/app.* after harness refactor

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
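For reference, the boundary check described above can be a single test that walks the harness tree and rejects any app-layer import. Below is a minimal sketch, assuming a pytest-style test and a repo-relative harness path; the actual contents of test_harness_boundary.py are not shown in this diff, so the names and layout here are illustrative only.

```python
# Illustrative sketch of a harness→app boundary check (not the actual
# test_harness_boundary.py from this PR; path and names are assumptions).
import re
from pathlib import Path

HARNESS_ROOT = Path("backend/packages/harness/deerflow")  # assumed repo-relative path
# Matches "import app", "import app.x", "from app import x", "from app.x import y"
FORBIDDEN = re.compile(r"^\s*(?:from|import)\s+app(?:[.\s]|$)", re.MULTILINE)


def test_harness_never_imports_app():
    offenders = [
        str(py_file)
        for py_file in HARNESS_ROOT.rglob("*.py")
        if FORBIDDEN.search(py_file.read_text(encoding="utf-8"))
    ]
    assert not offenders, f"harness layer imports from app layer: {offenders}"
```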
backend/packages/harness/deerflow/__init__.py (new file, 0 lines)

backend/packages/harness/deerflow/agents/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
from .lead_agent import make_lead_agent
from .thread_state import SandboxState, ThreadState

__all__ = ["make_lead_agent", "SandboxState", "ThreadState", "get_checkpointer", "reset_checkpointer", "make_checkpointer"]

backend/packages/harness/deerflow/agents/checkpointer/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from .async_provider import make_checkpointer
from .provider import checkpointer_context, get_checkpointer, reset_checkpointer

__all__ = [
    "get_checkpointer",
    "reset_checkpointer",
    "checkpointer_context",
    "make_checkpointer",
]

backend/packages/harness/deerflow/agents/checkpointer/async_provider.py (new file, 109 lines)
@@ -0,0 +1,109 @@
"""Async checkpointer factory.

Provides an **async context manager** for long-running async servers that need
proper resource cleanup.

Supported backends: memory, sqlite, postgres.

Usage (e.g. FastAPI lifespan)::

    from deerflow.agents.checkpointer.async_provider import make_checkpointer

    async with make_checkpointer() as checkpointer:
        app.state.checkpointer = checkpointer  # InMemorySaver if not configured

For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
"""

from __future__ import annotations

import contextlib
import logging
from collections.abc import AsyncIterator

from langgraph.types import Checkpointer

from deerflow.agents.checkpointer.provider import (
    POSTGRES_CONN_REQUIRED,
    POSTGRES_INSTALL,
    SQLITE_INSTALL,
    _resolve_sqlite_conn_str,
)
from deerflow.config.app_config import get_app_config

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Async factory
# ---------------------------------------------------------------------------


@contextlib.asynccontextmanager
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
    """Async context manager that constructs and tears down a checkpointer."""
    if config.type == "memory":
        from langgraph.checkpoint.memory import InMemorySaver

        yield InMemorySaver()
        return

    if config.type == "sqlite":
        try:
            from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
        except ImportError as exc:
            raise ImportError(SQLITE_INSTALL) from exc

        import pathlib

        conn_str = _resolve_sqlite_conn_str(config.connection_string or "store.db")
        # Only create parent directories for real filesystem paths
        if conn_str != ":memory:" and not conn_str.startswith("file:"):
            pathlib.Path(conn_str).parent.mkdir(parents=True, exist_ok=True)
        async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
            await saver.setup()
            yield saver
        return

    if config.type == "postgres":
        try:
            from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
        except ImportError as exc:
            raise ImportError(POSTGRES_INSTALL) from exc

        if not config.connection_string:
            raise ValueError(POSTGRES_CONN_REQUIRED)

        async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
            await saver.setup()
            yield saver
        return

    raise ValueError(f"Unknown checkpointer type: {config.type!r}")


# ---------------------------------------------------------------------------
# Public async context manager
# ---------------------------------------------------------------------------


@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
    """Async context manager that yields a checkpointer for the caller's lifetime.
    Resources are opened on enter and closed on exit — no global state::

        async with make_checkpointer() as checkpointer:
            app.state.checkpointer = checkpointer

    Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
    """

    config = get_app_config()

    if config.checkpointer is None:
        from langgraph.checkpoint.memory import InMemorySaver

        yield InMemorySaver()
        return

    async with _async_checkpointer(config.checkpointer) as saver:
        yield saver

backend/packages/harness/deerflow/agents/checkpointer/provider.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""Sync checkpointer factory.

Provides a **sync singleton** and a **sync context manager** for LangGraph
graph compilation and CLI tools.

Supported backends: memory, sqlite, postgres.

Usage::

    from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context

    # Singleton — reused across calls, closed on process exit
    cp = get_checkpointer()

    # One-shot — fresh connection, closed on block exit
    with checkpointer_context() as cp:
        graph.invoke(input, config={"configurable": {"thread_id": "1"}})
"""

from __future__ import annotations

import contextlib
import logging
from collections.abc import Iterator

from langgraph.types import Checkpointer

from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.paths import resolve_path

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Error message constants — imported by aio.provider too
# ---------------------------------------------------------------------------

SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"

# ---------------------------------------------------------------------------
# Sync factory
# ---------------------------------------------------------------------------


def _resolve_sqlite_conn_str(raw: str) -> str:
    """Return a SQLite connection string ready for use with ``SqliteSaver``.

    SQLite special strings (``":memory:"`` and ``file:`` URIs) are returned
    unchanged. Plain filesystem paths — relative or absolute — are resolved
    to an absolute string via :func:`resolve_path`.
    """
    if raw == ":memory:" or raw.startswith("file:"):
        return raw
    return str(resolve_path(raw))


@contextlib.contextmanager
def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
    """Context manager that creates and tears down a sync checkpointer.

    Returns a configured ``Checkpointer`` instance. Resource cleanup for any
    underlying connections or pools is handled by higher-level helpers in
    this module (such as the singleton factory or context manager); this
    function does not return a separate cleanup callback.
    """
    if config.type == "memory":
        from langgraph.checkpoint.memory import InMemorySaver

        logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
        yield InMemorySaver()
        return

    if config.type == "sqlite":
        try:
            from langgraph.checkpoint.sqlite import SqliteSaver
        except ImportError as exc:
            raise ImportError(SQLITE_INSTALL) from exc

        conn_str = _resolve_sqlite_conn_str(config.connection_string or "store.db")
        with SqliteSaver.from_conn_string(conn_str) as saver:
            saver.setup()
            logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
            yield saver
        return

    if config.type == "postgres":
        try:
            from langgraph.checkpoint.postgres import PostgresSaver
        except ImportError as exc:
            raise ImportError(POSTGRES_INSTALL) from exc

        if not config.connection_string:
            raise ValueError(POSTGRES_CONN_REQUIRED)

        with PostgresSaver.from_conn_string(config.connection_string) as saver:
            saver.setup()
            logger.info("Checkpointer: using PostgresSaver")
            yield saver
        return

    raise ValueError(f"Unknown checkpointer type: {config.type!r}")


# ---------------------------------------------------------------------------
# Sync singleton
# ---------------------------------------------------------------------------

_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None  # open context manager keeping the connection alive


def get_checkpointer() -> Checkpointer:
    """Return the global sync checkpointer singleton, creating it on first call.

    Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.

    Raises:
        ImportError: If the required package for the configured backend is not installed.
        ValueError: If ``connection_string`` is missing for a backend that requires it.
    """
    global _checkpointer, _checkpointer_ctx

    if _checkpointer is not None:
        return _checkpointer

    # Ensure app config is loaded before checking checkpointer config.
    # This prevents returning InMemorySaver when config.yaml actually has a
    # checkpointer section but hasn't been loaded yet.
    from deerflow.config.app_config import _app_config
    from deerflow.config.checkpointer_config import get_checkpointer_config

    if _app_config is None:
        # Only load config if it hasn't been initialized yet.
        # In tests, config may be set directly via set_checkpointer_config().
        try:
            get_app_config()
        except FileNotFoundError:
            # In test environments without config.yaml, this is expected.
            # Tests will set config directly via set_checkpointer_config().
            pass

    config = get_checkpointer_config()
    if config is None:
        from langgraph.checkpoint.memory import InMemorySaver

        logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
        _checkpointer = InMemorySaver()
        return _checkpointer

    _checkpointer_ctx = _sync_checkpointer_cm(config)
    _checkpointer = _checkpointer_ctx.__enter__()

    return _checkpointer


def reset_checkpointer() -> None:
    """Reset the sync singleton, forcing recreation on the next call.

    Closes any open backend connections and clears the cached instance.
    Useful in tests or after a configuration change.
    """
    global _checkpointer, _checkpointer_ctx
    if _checkpointer_ctx is not None:
        try:
            _checkpointer_ctx.__exit__(None, None, None)
        except Exception:
            logger.warning("Error during checkpointer cleanup", exc_info=True)
        _checkpointer_ctx = None
    _checkpointer = None


# ---------------------------------------------------------------------------
# Sync context manager
# ---------------------------------------------------------------------------


@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
    """Sync context manager that yields a checkpointer and cleans up on exit.

    Unlike :func:`get_checkpointer`, this does **not** cache the instance —
    each ``with`` block creates and destroys its own connection. Use it in
    CLI scripts or tests where you want deterministic cleanup::

        with checkpointer_context() as cp:
            graph.invoke(input, config={"configurable": {"thread_id": "1"}})

    Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
    """

    config = get_app_config()
    if config.checkpointer is None:
        from langgraph.checkpoint.memory import InMemorySaver

        yield InMemorySaver()
        return

    with _sync_checkpointer_cm(config.checkpointer) as saver:
        yield saver

backend/packages/harness/deerflow/agents/lead_agent/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .agent import make_lead_agent

__all__ = ["make_lead_agent"]

backend/packages/harness/deerflow/agents/lead_agent/agent.py (new file, 334 lines)
@@ -0,0 +1,334 @@
import logging

from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.runnables import RunnableConfig

from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config
from deerflow.config.app_config import get_app_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model

logger = logging.getLogger(__name__)


def _resolve_model_name(requested_model_name: str | None = None) -> str:
    """Resolve a runtime model name safely, falling back to the default if invalid; raises ValueError if no models are configured."""
    app_config = get_app_config()
    default_model_name = app_config.models[0].name if app_config.models else None
    if default_model_name is None:
        raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")

    if requested_model_name and app_config.get_model_config(requested_model_name):
        return requested_model_name

    if requested_model_name and requested_model_name != default_model_name:
        logger.warning(f"Model '{requested_model_name}' not found in config; falling back to default model '{default_model_name}'.")
    return default_model_name


def _create_summarization_middleware() -> SummarizationMiddleware | None:
    """Create and configure the summarization middleware from config."""
    config = get_summarization_config()

    if not config.enabled:
        return None

    # Prepare trigger parameter
    trigger = None
    if config.trigger is not None:
        if isinstance(config.trigger, list):
            trigger = [t.to_tuple() for t in config.trigger]
        else:
            trigger = config.trigger.to_tuple()

    # Prepare keep parameter
    keep = config.keep.to_tuple()

    # Prepare model parameter
    if config.model_name:
        model = config.model_name
    else:
        # Use a lightweight model for summarization to save costs.
        # Falls back to the default model if not explicitly specified.
        model = create_chat_model(thinking_enabled=False)

    # Prepare kwargs
    kwargs = {
        "model": model,
        "trigger": trigger,
        "keep": keep,
    }

    if config.trim_tokens_to_summarize is not None:
        kwargs["trim_tokens_to_summarize"] = config.trim_tokens_to_summarize

    if config.summary_prompt is not None:
        kwargs["summary_prompt"] = config.summary_prompt

    return SummarizationMiddleware(**kwargs)


def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
    """Create and configure the TodoList middleware.

    Args:
        is_plan_mode: Whether to enable plan mode with TodoList middleware.

    Returns:
        TodoMiddleware instance if plan mode is enabled, None otherwise.
    """
    if not is_plan_mode:
        return None

    # Custom prompts matching DeerFlow's style
    system_prompt = """
<todo_list_system>
You have access to the `write_todos` tool to help you manage and track complex multi-step objectives.

**CRITICAL RULES:**
- Mark todos as completed IMMEDIATELY after finishing each step - do NOT batch completions
- Keep EXACTLY ONE task as `in_progress` at any time (unless tasks can run in parallel)
- Update the todo list in REAL-TIME as you work - this gives users visibility into your progress
- DO NOT use this tool for simple tasks (< 3 steps) - just complete them directly

**When to Use:**
This tool is designed for complex objectives that require systematic tracking:
- Complex multi-step tasks requiring 3+ distinct steps
- Non-trivial tasks needing careful planning and execution
- User explicitly requests a todo list
- User provides multiple tasks (numbered or comma-separated list)
- The plan may need revisions based on intermediate results

**When NOT to Use:**
- Single, straightforward tasks
- Trivial tasks (< 3 steps)
- Purely conversational or informational requests
- Simple tool calls where the approach is obvious

**Best Practices:**
- Break down complex tasks into smaller, actionable steps
- Use clear, descriptive task names
- Remove tasks that become irrelevant
- Add new tasks discovered during implementation
- Don't be afraid to revise the todo list as you learn more

**Task Management:**
Writing todos takes time and tokens - use it when helpful for managing complex problems, not for simple requests.
</todo_list_system>
"""

    tool_description = """Use this tool to create and manage a structured task list for complex work sessions.

**IMPORTANT: Only use this tool for complex tasks (3+ steps). For simple requests, just do the work directly.**

## When to Use

Use this tool in these scenarios:
1. **Complex multi-step tasks**: When a task requires 3 or more distinct steps or actions
2. **Non-trivial tasks**: Tasks requiring careful planning or multiple operations
3. **User explicitly requests todo list**: When the user directly asks you to track tasks
4. **Multiple tasks**: When users provide a list of things to be done
5. **Dynamic planning**: When the plan may need updates based on intermediate results

## When NOT to Use

Skip this tool when:
1. The task is straightforward and takes less than 3 steps
2. The task is trivial and tracking provides no benefit
3. The task is purely conversational or informational
4. It's clear what needs to be done and you can just do it

## How to Use

1. **Starting a task**: Mark it as `in_progress` BEFORE beginning work
2. **Completing a task**: Mark it as `completed` IMMEDIATELY after finishing
3. **Updating the list**: Add new tasks, remove irrelevant ones, or update descriptions as needed
4. **Multiple updates**: You can make several updates at once (e.g., complete one task and start the next)

## Task States

- `pending`: Task not yet started
- `in_progress`: Currently working on (can have multiple if tasks run in parallel)
- `completed`: Task finished successfully

## Task Completion Requirements

**CRITICAL: Only mark a task as completed when you have FULLY accomplished it.**

Never mark a task as completed if:
- There are unresolved issues or errors
- Work is partial or incomplete
- You encountered blockers preventing completion
- You couldn't find necessary resources or dependencies
- Quality standards haven't been met

If blocked, keep the task as `in_progress` and create a new task describing what needs to be resolved.

## Best Practices

- Create specific, actionable items
- Break complex tasks into smaller, manageable steps
- Use clear, descriptive task names
- Update task status in real-time as you work
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
- Remove tasks that are no longer relevant
- **IMPORTANT**: When you write the todo list, mark your first task(s) as `in_progress` immediately
- **IMPORTANT**: Unless all tasks are completed, always have at least one task `in_progress` to show progress

Being proactive with task management demonstrates thoroughness and ensures all requirements are completed successfully.

**Remember**: If you only need a few tool calls to complete a task and it's clear what to do, it's better to just do the task directly and NOT use this tool at all.
"""

    return TodoMiddleware(system_prompt=system_prompt, tool_description=tool_description)


# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
# UploadsMiddleware should be after ThreadDataMiddleware to access thread_id
# DanglingToolCallMiddleware patches missing ToolMessages before model sees the history
# SummarizationMiddleware should be early to reduce context before other processing
# TodoListMiddleware should be before ClarificationMiddleware to allow todo management
# TitleMiddleware generates title after first exchange
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None):
    """Build middleware chain based on runtime configuration.

    Args:
        config: Runtime configuration containing configurable options like is_plan_mode.
        model_name: Resolved runtime model name, used to decide whether vision middleware applies.
        agent_name: If provided, MemoryMiddleware will use per-agent memory storage.

    Returns:
        List of middleware instances.
    """
    middlewares = build_lead_runtime_middlewares(lazy_init=True)

    # Add summarization middleware if enabled
    summarization_middleware = _create_summarization_middleware()
    if summarization_middleware is not None:
        middlewares.append(summarization_middleware)

    # Add TodoList middleware if plan mode is enabled
    is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
    todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
    if todo_list_middleware is not None:
        middlewares.append(todo_list_middleware)

    # Add TitleMiddleware
    middlewares.append(TitleMiddleware())

    # Add MemoryMiddleware (after TitleMiddleware)
    middlewares.append(MemoryMiddleware(agent_name=agent_name))

    # Add ViewImageMiddleware only if the current model supports vision.
    # Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
    app_config = get_app_config()
    model_config = app_config.get_model_config(model_name) if model_name else None
    if model_config is not None and model_config.supports_vision:
        middlewares.append(ViewImageMiddleware())

    # Add SubagentLimitMiddleware to truncate excess parallel task calls
    subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
    if subagent_enabled:
        max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
        middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))

    # LoopDetectionMiddleware — detect and break repetitive tool call loops
    middlewares.append(LoopDetectionMiddleware())

    # ClarificationMiddleware should always be last
    middlewares.append(ClarificationMiddleware())
    return middlewares


def make_lead_agent(config: RunnableConfig):
    # Lazy import to avoid circular dependency
    from deerflow.tools import get_available_tools
    from deerflow.tools.builtins import setup_agent

    cfg = config.get("configurable", {})

    thinking_enabled = cfg.get("thinking_enabled", True)
    reasoning_effort = cfg.get("reasoning_effort", None)
    requested_model_name: str | None = cfg.get("model_name") or cfg.get("model")
    is_plan_mode = cfg.get("is_plan_mode", False)
    subagent_enabled = cfg.get("subagent_enabled", False)
    max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
    is_bootstrap = cfg.get("is_bootstrap", False)
    agent_name = cfg.get("agent_name")

    agent_config = load_agent_config(agent_name) if not is_bootstrap else None
    # Custom agent model, or fall back to global/default model resolution
    agent_model_name = agent_config.model if agent_config and agent_config.model else _resolve_model_name()

    # Final model name resolution: request override, then agent config, then global default
    model_name = requested_model_name or agent_model_name

    app_config = get_app_config()
    model_config = app_config.get_model_config(model_name) if model_name else None

    if model_config is None:
        raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
    if thinking_enabled and not model_config.supports_thinking:
        logger.warning(f"Thinking mode is enabled but model '{model_name}' does not support it; falling back to non-thinking mode.")
        thinking_enabled = False

    logger.info(
        "Create Agent(%s) -> thinking_enabled: %s, reasoning_effort: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s",
        agent_name or "default",
        thinking_enabled,
        reasoning_effort,
        model_name,
        is_plan_mode,
        subagent_enabled,
        max_concurrent_subagents,
    )

    # Inject run metadata for LangSmith trace tagging
    if "metadata" not in config:
        config["metadata"] = {}

    config["metadata"].update(
        {
            "agent_name": agent_name or "default",
            "model_name": model_name or "default",
            "thinking_enabled": thinking_enabled,
            "reasoning_effort": reasoning_effort,
            "is_plan_mode": is_plan_mode,
            "subagent_enabled": subagent_enabled,
        }
    )

    if is_bootstrap:
        # Special bootstrap agent with a minimal prompt for the initial custom-agent creation flow
        system_prompt = apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills={"bootstrap"})

        return create_agent(
            model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
            tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
            middleware=_build_middlewares(config, model_name=model_name),
            system_prompt=system_prompt,
            state_schema=ThreadState,
        )

    # Default lead agent (unchanged behavior)
    return create_agent(
        model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
        tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
        middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
        system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name),
        state_schema=ThreadState,
    )

backend/packages/harness/deerflow/agents/lead_agent/prompt.py (new file, 409 lines)
@@ -0,0 +1,409 @@
from datetime import datetime

from deerflow.config.agents_config import load_agent_soul
from deerflow.skills import load_skills


def _build_subagent_section(max_concurrent: int) -> str:
    """Build the subagent system prompt section with dynamic concurrency limit.

    Args:
        max_concurrent: Maximum number of concurrent subagent calls allowed per response.

    Returns:
        Formatted subagent section string.
    """
    n = max_concurrent
    return f"""<subagent_system>
**🚀 SUBAGENT MODE ACTIVE - DECOMPOSE, DELEGATE, SYNTHESIZE**

You are running with subagent capabilities enabled. Your role is to be a **task orchestrator**:
1. **DECOMPOSE**: Break complex tasks into parallel sub-tasks
2. **DELEGATE**: Launch multiple subagents simultaneously using parallel `task` calls
3. **SYNTHESIZE**: Collect and integrate results into a coherent answer

**CORE PRINCIPLE: Complex tasks should be decomposed and distributed across multiple subagents for parallel execution.**

**⛔ HARD CONCURRENCY LIMIT: MAXIMUM {n} `task` CALLS PER RESPONSE. THIS IS NOT OPTIONAL.**
- Each response, you may include **at most {n}** `task` tool calls. Any excess calls are **silently discarded** by the system — you will lose that work.
- **Before launching subagents, you MUST count your sub-tasks in your thinking:**
  - If count ≤ {n}: Launch all in this response.
  - If count > {n}: **Pick the {n} most important/foundational sub-tasks for this turn.** Save the rest for the next turn.
- **Multi-batch execution** (for >{n} sub-tasks):
  - Turn 1: Launch sub-tasks 1-{n} in parallel → wait for results
  - Turn 2: Launch next batch in parallel → wait for results
  - ... continue until all sub-tasks are complete
  - Final turn: Synthesize ALL results into a coherent answer
- **Example thinking pattern**: "I identified 6 sub-tasks. Since the limit is {n} per turn, I will launch the first {n} now, and the rest in the next turn."

**Available Subagents:**
- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.
- **bash**: For command execution (git, build, test, deploy operations)

**Your Orchestration Strategy:**

✅ **DECOMPOSE + PARALLEL EXECUTION (Preferred Approach):**

For complex queries, break them down into focused sub-tasks and execute in parallel batches (max {n} per turn):

**Example 1: "Why is Tencent's stock price declining?" (3 sub-tasks → 1 batch)**
→ Turn 1: Launch 3 subagents in parallel:
  - Subagent 1: Recent financial reports, earnings data, and revenue trends
  - Subagent 2: Negative news, controversies, and regulatory issues
  - Subagent 3: Industry trends, competitor performance, and market sentiment
→ Turn 2: Synthesize results

**Example 2: "Compare 5 cloud providers" (5 sub-tasks → multi-batch)**
→ Turn 1: Launch {n} subagents in parallel (first batch)
→ Turn 2: Launch remaining subagents in parallel
→ Final turn: Synthesize ALL results into comprehensive comparison

**Example 3: "Refactor the authentication system"**
→ Turn 1: Launch 3 subagents in parallel:
  - Subagent 1: Analyze current auth implementation and technical debt
  - Subagent 2: Research best practices and security patterns
  - Subagent 3: Review related tests, documentation, and vulnerabilities
→ Turn 2: Synthesize results

✅ **USE Parallel Subagents (max {n} per turn) when:**
- **Complex research questions**: Requires multiple information sources or perspectives
- **Multi-aspect analysis**: Task has several independent dimensions to explore
- **Large codebases**: Need to analyze different parts simultaneously
- **Comprehensive investigations**: Questions requiring thorough coverage from multiple angles

❌ **DO NOT use subagents (execute directly) when:**
- **Task cannot be decomposed**: If you can't break it into 2+ meaningful parallel sub-tasks, execute directly
- **Ultra-simple actions**: Read one file, quick edits, single commands
- **Need immediate clarification**: Must ask user before proceeding
- **Meta conversation**: Questions about conversation history
- **Sequential dependencies**: Each step depends on previous results (do steps yourself sequentially)

**CRITICAL WORKFLOW** (STRICTLY follow this before EVERY action):
1. **COUNT**: In your thinking, list all sub-tasks and count them explicitly: "I have N sub-tasks"
2. **PLAN BATCHES**: If N > {n}, explicitly plan which sub-tasks go in which batch:
   - "Batch 1 (this turn): first {n} sub-tasks"
   - "Batch 2 (next turn): next batch of sub-tasks"
3. **EXECUTE**: Launch ONLY the current batch (max {n} `task` calls). Do NOT launch sub-tasks from future batches.
4. **REPEAT**: After results return, launch the next batch. Continue until all batches complete.
5. **SYNTHESIZE**: After ALL batches are done, synthesize all results.
6. **Cannot decompose** → Execute directly using available tools (bash, read_file, web_search, etc.)

**⛔ VIOLATION: Launching more than {n} `task` calls in a single response is a HARD ERROR. The system WILL discard excess calls and you WILL lose work. Always batch.**

**Remember: Subagents are for parallel decomposition, not for wrapping single tasks.**

**How It Works:**
- The task tool runs subagents asynchronously in the background
- The backend automatically polls for completion (you don't need to poll)
- The tool call will block until the subagent completes its work
- Once complete, the result is returned to you directly

**Usage Example 1 - Single Batch (≤{n} sub-tasks):**

```python
# User asks: "Why is Tencent's stock price declining?"
# Thinking: 3 sub-tasks → fits in 1 batch

# Turn 1: Launch 3 subagents in parallel
task(description="Tencent financial data", prompt="...", subagent_type="general-purpose")
task(description="Tencent news & regulation", prompt="...", subagent_type="general-purpose")
task(description="Industry & market trends", prompt="...", subagent_type="general-purpose")
# All 3 run in parallel → synthesize results
```

**Usage Example 2 - Multiple Batches (>{n} sub-tasks):**

```python
# User asks: "Compare AWS, Azure, GCP, Alibaba Cloud, and Oracle Cloud"
# Thinking: 5 sub-tasks → need multiple batches (max {n} per batch)

# Turn 1: Launch first batch of {n}
task(description="AWS analysis", prompt="...", subagent_type="general-purpose")
task(description="Azure analysis", prompt="...", subagent_type="general-purpose")
task(description="GCP analysis", prompt="...", subagent_type="general-purpose")

# Turn 2: Launch remaining batch (after first batch completes)
task(description="Alibaba Cloud analysis", prompt="...", subagent_type="general-purpose")
task(description="Oracle Cloud analysis", prompt="...", subagent_type="general-purpose")

# Turn 3: Synthesize ALL results from both batches
```

**Counter-Example - Direct Execution (NO subagents):**

```python
# User asks: "Run the tests"
# Thinking: Cannot decompose into parallel sub-tasks
# → Execute directly

bash("npm test")  # Direct execution, not task()
```

**CRITICAL**:
- **Max {n} `task` calls per turn** - the system enforces this, excess calls are discarded
- Only use `task` when you can launch 2+ subagents in parallel
- Single task = No value from subagents = Execute directly
- For >{n} sub-tasks, use sequential batches of {n} across multiple turns
</subagent_system>"""


SYSTEM_PROMPT_TEMPLATE = """
<role>
You are {agent_name}, an open-source super agent.
</role>

{soul}
{memory_context}

<thinking_style>
- Think concisely and strategically about the user's request BEFORE taking action
- Break down the task: What is clear? What is ambiguous? What is missing?
- **PRIORITY CHECK: If anything is unclear, missing, or has multiple interpretations, you MUST ask for clarification FIRST - do NOT proceed with work**
{subagent_thinking}- Never write down your full final answer or report in thinking process, but only outline
- CRITICAL: After thinking, you MUST provide your actual response to the user. Thinking is for planning, the response is for delivery.
- Your response must contain the actual answer, not just a reference to what you thought about
</thinking_style>

<clarification_system>
**WORKFLOW PRIORITY: CLARIFY → PLAN → ACT**
1. **FIRST**: Analyze the request in your thinking - identify what's unclear, missing, or ambiguous
2. **SECOND**: If clarification is needed, call `ask_clarification` tool IMMEDIATELY - do NOT start working
3. **THIRD**: Only after all clarifications are resolved, proceed with planning and execution

**CRITICAL RULE: Clarification ALWAYS comes BEFORE action. Never start working and clarify mid-execution.**

**MANDATORY Clarification Scenarios - You MUST call ask_clarification BEFORE starting work when:**

1. **Missing Information** (`missing_info`): Required details not provided
   - Example: User says "create a web scraper" but doesn't specify the target website
   - Example: "Deploy the app" without specifying environment
   - **REQUIRED ACTION**: Call ask_clarification to get the missing information

2. **Ambiguous Requirements** (`ambiguous_requirement`): Multiple valid interpretations exist
   - Example: "Optimize the code" could mean performance, readability, or memory usage
   - Example: "Make it better" is unclear what aspect to improve
   - **REQUIRED ACTION**: Call ask_clarification to clarify the exact requirement

3. **Approach Choices** (`approach_choice`): Several valid approaches exist
   - Example: "Add authentication" could use JWT, OAuth, session-based, or API keys
   - Example: "Store data" could use database, files, cache, etc.
   - **REQUIRED ACTION**: Call ask_clarification to let user choose the approach

4. **Risky Operations** (`risk_confirmation`): Destructive actions need confirmation
   - Example: Deleting files, modifying production configs, database operations
   - Example: Overwriting existing code or data
   - **REQUIRED ACTION**: Call ask_clarification to get explicit confirmation

5. **Suggestions** (`suggestion`): You have a recommendation but want approval
   - Example: "I recommend refactoring this code. Should I proceed?"
   - **REQUIRED ACTION**: Call ask_clarification to get approval

**STRICT ENFORCEMENT:**
- ❌ DO NOT start working and then ask for clarification mid-execution - clarify FIRST
- ❌ DO NOT skip clarification for "efficiency" - accuracy matters more than speed
- ❌ DO NOT make assumptions when information is missing - ALWAYS ask
- ❌ DO NOT proceed with guesses - STOP and call ask_clarification first
- ✅ Analyze the request in thinking → Identify unclear aspects → Ask BEFORE any action
- ✅ If you identify the need for clarification in your thinking, you MUST call the tool IMMEDIATELY
- ✅ After calling ask_clarification, execution will be interrupted automatically
- ✅ Wait for user response - do NOT continue with assumptions

**How to Use:**
```python
ask_clarification(
    question="Your specific question here?",
    clarification_type="missing_info",  # or other type
    context="Why you need this information",  # optional but recommended
    options=["option1", "option2"]  # optional, for choices
)
```

**Example:**
User: "Deploy the application"
You (thinking): Missing environment info - I MUST ask for clarification
You (action): ask_clarification(
    question="Which environment should I deploy to?",
    clarification_type="approach_choice",
    context="I need to know the target environment for proper configuration",
    options=["development", "staging", "production"]
)
[Execution stops - wait for user response]

User: "staging"
You: "Deploying to staging..." [proceed]
</clarification_system>

{skills_section}

{subagent_section}

<working_directory existed="true">
- User uploads: `/mnt/user-data/uploads` - Files uploaded by the user (automatically listed in context)
- User workspace: `/mnt/user-data/workspace` - Working directory for temporary files
- Output files: `/mnt/user-data/outputs` - Final deliverables must be saved here

**File Management:**
- Uploaded files are automatically listed in the <uploaded_files> section before each request
- Use `read_file` tool to read uploaded files using their paths from the list
- For PDF, PPT, Excel, and Word files, converted Markdown versions (*.md) are available alongside originals
- All temporary work happens in `/mnt/user-data/workspace`
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool
</working_directory>

<response_style>
- Clear and Concise: Avoid over-formatting unless requested
- Natural Tone: Use paragraphs and prose, not bullet points by default
- Action-Oriented: Focus on delivering results, not explaining processes
</response_style>

<citations>
- When to Use: After web_search, include citations if applicable
- Format: Use Markdown link format `[citation:TITLE](URL)`
- Example:
  ```markdown
  The key AI trends for 2026 include enhanced reasoning capabilities and multimodal integration
  [citation:AI Trends 2026](https://techcrunch.com/ai-trends).
  Recent breakthroughs in language models have also accelerated progress
  [citation:OpenAI Research](https://openai.com/research).
  ```
</citations>

<critical_reminders>
- **Clarification First**: ALWAYS clarify unclear/missing/ambiguous requirements BEFORE starting work - never assume or guess
{subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks.
- Progressive Loading: Load resources incrementally as referenced in skills
- Output Files: Final deliverables must be in `/mnt/user-data/outputs`
- Clarity: Be direct and helpful, avoid unnecessary meta-commentary
- Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `\n\n` or "```mermaid" to display images in response or Markdown files
- Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance
- Language Consistency: Keep using the same language as user's
- Always Respond: Your thinking is internal. You MUST always provide a visible response to the user after thinking.
</critical_reminders>
"""


def _get_memory_context(agent_name: str | None = None) -> str:
    """Get memory context for injection into system prompt.

    Args:
        agent_name: If provided, loads per-agent memory. If None, loads global memory.

    Returns:
        Formatted memory context string wrapped in XML tags, or empty string if disabled.
    """
    try:
        from deerflow.agents.memory import format_memory_for_injection, get_memory_data
        from deerflow.config.memory_config import get_memory_config

        config = get_memory_config()
        if not config.enabled or not config.injection_enabled:
            return ""

        memory_data = get_memory_data(agent_name)
        memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)

        if not memory_content.strip():
            return ""

        return f"""<memory>
{memory_content}
</memory>
"""
    except Exception as e:
        print(f"Failed to load memory context: {e}")
        return ""


def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
    """Generate the skills prompt section with available skills list.

    Returns the <skill_system>...</skill_system> block listing all enabled skills,
    suitable for injection into any agent's system prompt.
    """
    skills = load_skills(enabled_only=True)

    try:
        from deerflow.config import get_app_config

        config = get_app_config()
        container_base_path = config.skills.container_path
    except Exception:
        container_base_path = "/mnt/skills"

    if not skills:
        return ""

    if available_skills is not None:
        skills = [skill for skill in skills if skill.name in available_skills]

    skill_items = "\n".join(
        f"  <skill>\n    <name>{skill.name}</name>\n    <description>{skill.description}</description>\n    <location>{skill.get_container_file_path(container_base_path)}</location>\n  </skill>" for skill in skills
    )
    skills_list = f"<available_skills>\n{skill_items}\n</available_skills>"

    return f"""<skill_system>
You have access to skills that provide optimized workflows for specific tasks. Each skill contains best practices, frameworks, and references to additional resources.

**Progressive Loading Pattern:**
1. When a user query matches a skill's use case, immediately call `read_file` on the skill's main file using the path attribute provided in the skill tag below
2. Read and understand the skill's workflow and instructions
3. The skill file contains references to external resources under the same folder
4. Load referenced resources only when needed during execution
5. Follow the skill's instructions precisely

**Skills are located at:** {container_base_path}

{skills_list}

</skill_system>"""


def get_agent_soul(agent_name: str | None) -> str:
    # Append SOUL.md (agent personality) if present
    soul = load_agent_soul(agent_name)
    if soul:
        return f"<soul>\n{soul}\n</soul>\n"
    return ""


def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
    # Get memory context
    memory_context = _get_memory_context(agent_name)

    # Include subagent section only if enabled (from runtime parameter)
    n = max_concurrent_subagents
    subagent_section = _build_subagent_section(n) if subagent_enabled else ""

    # Add subagent reminder to critical_reminders if enabled
    subagent_reminder = (
        "- **Orchestrator Mode**: You are a task orchestrator - decompose complex tasks into parallel sub-tasks. "
        f"**HARD LIMIT: max {n} `task` calls per response.** "
        f"If >{n} sub-tasks, split into sequential batches of ≤{n}. Synthesize after ALL batches complete.\n"
        if subagent_enabled
        else ""
    )

    # Add subagent thinking guidance if enabled
    subagent_thinking = (
        "- **DECOMPOSITION CHECK: Can this task be broken into 2+ parallel sub-tasks? If YES, COUNT them. "
        f"If count > {n}, you MUST plan batches of ≤{n} and only launch the FIRST batch now. "
        f"NEVER launch more than {n} `task` calls in one response.**\n"
        if subagent_enabled
        else ""
    )

    # Get skills section
    skills_section = get_skills_prompt_section(available_skills)

    # Format the prompt with dynamic skills and memory
    prompt = SYSTEM_PROMPT_TEMPLATE.format(
        agent_name=agent_name or "DeerFlow 2.0",
        soul=get_agent_soul(agent_name),
        skills_section=skills_section,
        memory_context=memory_context,
        subagent_section=subagent_section,
        subagent_reminder=subagent_reminder,
        subagent_thinking=subagent_thinking,
    )

    return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"

backend/packages/harness/deerflow/agents/memory/__init__.py (new file, 44 lines)
@@ -0,0 +1,44 @@
"""Memory module for DeerFlow.

This module provides a global memory mechanism that:
- Stores user context and conversation history in memory.json
- Uses LLM to summarize and extract facts from conversations
- Injects relevant memory into system prompts for personalized responses
"""

from deerflow.agents.memory.prompt import (
    FACT_EXTRACTION_PROMPT,
    MEMORY_UPDATE_PROMPT,
    format_conversation_for_update,
    format_memory_for_injection,
)
from deerflow.agents.memory.queue import (
    ConversationContext,
    MemoryUpdateQueue,
    get_memory_queue,
    reset_memory_queue,
)
from deerflow.agents.memory.updater import (
    MemoryUpdater,
    get_memory_data,
    reload_memory_data,
    update_memory_from_conversation,
)

__all__ = [
    # Prompt utilities
    "MEMORY_UPDATE_PROMPT",
    "FACT_EXTRACTION_PROMPT",
    "format_memory_for_injection",
    "format_conversation_for_update",
    # Queue
    "ConversationContext",
    "MemoryUpdateQueue",
    "get_memory_queue",
    "reset_memory_queue",
    # Updater
    "MemoryUpdater",
    "get_memory_data",
    "reload_memory_data",
    "update_memory_from_conversation",
]

backend/packages/harness/deerflow/agents/memory/prompt.py (new file, 339 lines)
@@ -0,0 +1,339 @@
"""Prompt templates for memory update and injection."""

import math
import re
from typing import Any

try:
    import tiktoken

    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False

# Prompt template for updating memory based on conversation
MEMORY_UPDATE_PROMPT = """You are a memory management system. Your task is to analyze a conversation and update the user's memory profile.

Current Memory State:
<current_memory>
{current_memory}
</current_memory>

New Conversation to Process:
<conversation>
{conversation}
</conversation>

Instructions:
1. Analyze the conversation for important information about the user
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
3. Update the memory sections as needed following the detailed length guidelines below

Memory Section Guidelines:

**User Context** (Current state - concise summaries):
- workContext: Professional role, company, key projects, main technologies (2-3 sentences)
  Example: Core contributor, project names with metrics (16k+ stars), technical stack
- personalContext: Languages, communication preferences, key interests (1-2 sentences)
  Example: Bilingual capabilities, specific interest areas, expertise domains
- topOfMind: Multiple ongoing focus areas and priorities (3-5 sentences, detailed paragraph)
  Example: Primary project work, parallel technical investigations, ongoing learning/tracking
  Include: Active implementation work, troubleshooting issues, market/research interests
  Note: This captures SEVERAL concurrent focus areas, not just one task

**History** (Temporal context - rich paragraphs):
- recentMonths: Detailed summary of recent activities (4-6 sentences or 1-2 paragraphs)
  Timeline: Last 1-3 months of interactions
  Include: Technologies explored, projects worked on, problems solved, interests demonstrated
- earlierContext: Important historical patterns (3-5 sentences or 1 paragraph)
  Timeline: 3-12 months ago
  Include: Past projects, learning journeys, established patterns
- longTermBackground: Persistent background and foundational context (2-4 sentences)
  Timeline: Overall/foundational information
  Include: Core expertise, longstanding interests, fundamental working style

**Facts Extraction**:
- Extract specific, quantifiable details (e.g., "16k+ GitHub stars", "200+ datasets")
- Include proper nouns (company names, project names, technology names)
- Preserve technical terminology and version numbers
- Categories:
  * preference: Tools, styles, approaches user prefers/dislikes
  * knowledge: Specific expertise, technologies mastered, domain knowledge
  * context: Background facts (job title, projects, locations, languages)
  * behavior: Working patterns, communication habits, problem-solving approaches
  * goal: Stated objectives, learning targets, project ambitions
- Confidence levels:
  * 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
  * 0.7-0.8: Strongly implied from actions/discussions
  * 0.5-0.6: Inferred patterns (use sparingly, only for clear patterns)

**What Goes Where**:
- workContext: Current job, active projects, primary tech stack
|
||||
- personalContext: Languages, personality, interests outside direct work tasks
|
||||
- topOfMind: Multiple ongoing priorities and focus areas user cares about recently (gets updated most frequently)
|
||||
Should capture 3-5 concurrent themes: main work, side explorations, learning/tracking interests
|
||||
- recentMonths: Detailed account of recent technical explorations and work
|
||||
- earlierContext: Patterns from slightly older interactions still relevant
|
||||
- longTermBackground: Unchanging foundational facts about the user
|
||||
|
||||
**Multilingual Content**:
|
||||
- Preserve original language for proper nouns and company names
|
||||
- Keep technical terms in their original form (DeepSeek, LangGraph, etc.)
|
||||
- Note language capabilities in personalContext
|
||||
|
||||
Output Format (JSON):
|
||||
{{
|
||||
"user": {{
|
||||
"workContext": {{ "summary": "...", "shouldUpdate": true/false }},
|
||||
"personalContext": {{ "summary": "...", "shouldUpdate": true/false }},
|
||||
"topOfMind": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||
}},
|
||||
"history": {{
|
||||
"recentMonths": {{ "summary": "...", "shouldUpdate": true/false }},
|
||||
"earlierContext": {{ "summary": "...", "shouldUpdate": true/false }},
|
||||
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||
}},
|
||||
"newFacts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
],
|
||||
"factsToRemove": ["fact_id_1", "fact_id_2"]
|
||||
}}
|
||||
|
||||
Important Rules:
|
||||
- Only set shouldUpdate=true if there's meaningful new information
|
||||
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
|
||||
- Include specific metrics, version numbers, and proper nouns in facts
|
||||
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
|
||||
- Remove facts that are contradicted by new information
|
||||
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
|
||||
Keep 3-5 concurrent focus themes that are still active and relevant
|
||||
- For history sections, integrate new information chronologically into appropriate time period
|
||||
- Preserve technical accuracy - keep exact names of technologies, companies, projects
|
||||
- Focus on information useful for future interactions and personalization
|
||||
- IMPORTANT: Do NOT record file upload events in memory. Uploaded files are
|
||||
session-specific and ephemeral — they will not be accessible in future sessions.
|
||||
Recording upload events causes confusion in subsequent conversations.
|
||||
|
||||
Return ONLY valid JSON, no explanation or markdown."""
|
||||
|
||||
|
||||
# Prompt template for extracting facts from a single message
|
||||
FACT_EXTRACTION_PROMPT = """Extract factual information about the user from this message.
|
||||
|
||||
Message:
|
||||
{message}
|
||||
|
||||
Extract facts in this JSON format:
|
||||
{{
|
||||
"facts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
]
|
||||
}}
|
||||
|
||||
Categories:
|
||||
- preference: User preferences (likes/dislikes, styles, tools)
|
||||
- knowledge: User's expertise or knowledge areas
|
||||
- context: Background context (location, job, projects)
|
||||
- behavior: Behavioral patterns
|
||||
- goal: User's goals or objectives
|
||||
|
||||
Rules:
|
||||
- Only extract clear, specific facts
|
||||
- Confidence should reflect certainty (explicit statement = 0.9+, implied = 0.6-0.8)
|
||||
- Skip vague or temporary information
|
||||
|
||||
Return ONLY valid JSON."""
|
||||
|
||||
|
||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
||||
"""Count tokens in text using tiktoken.
|
||||
|
||||
Args:
|
||||
text: The text to count tokens for.
|
||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||
|
||||
Returns:
|
||||
The number of tokens in the text.
|
||||
"""
|
||||
if not TIKTOKEN_AVAILABLE:
|
||||
# Fallback to character-based estimation if tiktoken is not available
|
||||
return len(text) // 4
|
||||
|
||||
try:
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
return len(encoding.encode(text))
|
||||
except Exception:
|
||||
# Fallback to character-based estimation on error
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
||||
"""Coerce a confidence-like value to a bounded float in [0, 1].
|
||||
|
||||
Non-finite values (NaN, inf, -inf) are treated as invalid and fall back
|
||||
to the default before clamping, preventing them from dominating ranking.
|
||||
The ``default`` parameter is assumed to be a finite value.
|
||||
"""
|
||||
try:
|
||||
confidence = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return max(0.0, min(1.0, default))
|
||||
if not math.isfinite(confidence):
|
||||
return max(0.0, min(1.0, default))
|
||||
return max(0.0, min(1.0, confidence))
|
||||
|
||||
|
||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||
"""Format memory data for injection into system prompt.
|
||||
|
||||
Args:
|
||||
memory_data: The memory data dictionary.
|
||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||
|
||||
Returns:
|
||||
Formatted memory string for system prompt injection.
|
||||
"""
|
||||
if not memory_data:
|
||||
return ""
|
||||
|
||||
sections = []
|
||||
|
||||
# Format user context
|
||||
user_data = memory_data.get("user", {})
|
||||
if user_data:
|
||||
user_sections = []
|
||||
|
||||
work_ctx = user_data.get("workContext", {})
|
||||
if work_ctx.get("summary"):
|
||||
user_sections.append(f"Work: {work_ctx['summary']}")
|
||||
|
||||
personal_ctx = user_data.get("personalContext", {})
|
||||
if personal_ctx.get("summary"):
|
||||
user_sections.append(f"Personal: {personal_ctx['summary']}")
|
||||
|
||||
top_of_mind = user_data.get("topOfMind", {})
|
||||
if top_of_mind.get("summary"):
|
||||
user_sections.append(f"Current Focus: {top_of_mind['summary']}")
|
||||
|
||||
if user_sections:
|
||||
sections.append("User Context:\n" + "\n".join(f"- {s}" for s in user_sections))
|
||||
|
||||
# Format history
|
||||
history_data = memory_data.get("history", {})
|
||||
if history_data:
|
||||
history_sections = []
|
||||
|
||||
recent = history_data.get("recentMonths", {})
|
||||
if recent.get("summary"):
|
||||
history_sections.append(f"Recent: {recent['summary']}")
|
||||
|
||||
earlier = history_data.get("earlierContext", {})
|
||||
if earlier.get("summary"):
|
||||
history_sections.append(f"Earlier: {earlier['summary']}")
|
||||
|
||||
if history_sections:
|
||||
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
|
||||
|
||||
# Format facts (sorted by confidence; include as many as token budget allows)
|
||||
facts_data = memory_data.get("facts", [])
|
||||
if isinstance(facts_data, list) and facts_data:
|
||||
ranked_facts = sorted(
|
||||
(
|
||||
f
|
||||
for f in facts_data
|
||||
if isinstance(f, dict)
|
||||
and isinstance(f.get("content"), str)
|
||||
and f.get("content").strip()
|
||||
),
|
||||
key=lambda fact: _coerce_confidence(fact.get("confidence"), default=0.0),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Compute token count for existing sections once, then account
|
||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||
base_text = "\n\n".join(sections)
|
||||
base_tokens = _count_tokens(base_text) if base_text else 0
|
||||
# Account for the separator between existing sections and the facts section.
|
||||
facts_header = "Facts:\n"
|
||||
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
||||
running_tokens = base_tokens + separator_tokens
|
||||
|
||||
fact_lines: list[str] = []
|
||||
for fact in ranked_facts:
|
||||
content_value = fact.get("content")
|
||||
if not isinstance(content_value, str):
|
||||
continue
|
||||
content = content_value.strip()
|
||||
if not content:
|
||||
continue
|
||||
category = str(fact.get("category", "context")).strip() or "context"
|
||||
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
line_tokens = _count_tokens(line_text)
|
||||
|
||||
if running_tokens + line_tokens <= max_tokens:
|
||||
fact_lines.append(line)
|
||||
running_tokens += line_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
if fact_lines:
|
||||
sections.append("Facts:\n" + "\n".join(fact_lines))
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
result = "\n\n".join(sections)
|
||||
|
||||
# Use accurate token counting with tiktoken
|
||||
token_count = _count_tokens(result)
|
||||
if token_count > max_tokens:
|
||||
# Truncate to fit within token limit
|
||||
# Estimate characters to remove based on token ratio
|
||||
char_per_token = len(result) / token_count
|
||||
target_chars = int(max_tokens * char_per_token * 0.95) # 95% to leave margin
|
||||
result = result[:target_chars] + "\n..."
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def format_conversation_for_update(messages: list[Any]) -> str:
|
||||
"""Format conversation messages for memory update prompt.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
|
||||
Returns:
|
||||
Formatted conversation string.
|
||||
"""
|
||||
lines = []
|
||||
for msg in messages:
|
||||
role = getattr(msg, "type", "unknown")
|
||||
content = getattr(msg, "content", str(msg))
|
||||
|
||||
# Handle content that might be a list (multimodal)
|
||||
if isinstance(content, list):
|
||||
text_parts = [p.get("text", "") for p in content if isinstance(p, dict) and "text" in p]
|
||||
content = " ".join(text_parts) if text_parts else str(content)
|
||||
|
||||
# Strip uploaded_files tags from human messages to avoid persisting
|
||||
# ephemeral file path info into long-term memory. Skip the turn entirely
|
||||
# when nothing remains after stripping (upload-only message).
|
||||
if role == "human":
|
||||
content = re.sub(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", "", str(content)).strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# Truncate very long messages
|
||||
if len(str(content)) > 1000:
|
||||
content = str(content)[:1000] + "..."
|
||||
|
||||
if role == "human":
|
||||
lines.append(f"User: {content}")
|
||||
elif role == "ai":
|
||||
lines.append(f"Assistant: {content}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
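A small sketch of the formatting behavior above: malformed confidence values are clamped rather than raised, and facts are ranked before the token budget is applied (the data below is invented):

    memory = {
        "user": {"workContext": {"summary": "Maintains DeerFlow (16k+ stars)."}},
        "facts": [
            {"content": "Prefers uv over pip", "category": "preference", "confidence": 0.9},
            {"content": "Tracks LangGraph releases", "category": "behavior", "confidence": "oops"},
        ],
    }
    text = format_memory_for_injection(memory, max_tokens=2000)
    # -> a "User Context:" section plus a "Facts:" section; the bad confidence
    #    is coerced to 0.00 and ranked last instead of raising a TypeError.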
backend/packages/harness/deerflow/agents/memory/queue.py (new file, +195)
@@ -0,0 +1,195 @@
"""Memory update queue with debounce mechanism."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationContext:
|
||||
"""Context for a conversation to be processed for memory update."""
|
||||
|
||||
thread_id: str
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
agent_name: str | None = None
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
"""Queue for memory updates with debounce mechanism.
|
||||
|
||||
This queue collects conversation contexts and processes them after
|
||||
a configurable debounce period. Multiple conversations received within
|
||||
the debounce window are batched together.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the memory update queue."""
|
||||
self._queue: list[ConversationContext] = []
|
||||
self._lock = threading.Lock()
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
# Check if this thread already has a pending update
|
||||
# If so, replace it with the newer one
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
self._queue.append(context)
|
||||
|
||||
# Reset or start the debounce timer
|
||||
self._reset_timer()
|
||||
|
||||
print(f"Memory update queued for thread {thread_id}, queue size: {len(self._queue)}")
|
||||
|
||||
def _reset_timer(self) -> None:
|
||||
"""Reset the debounce timer."""
|
||||
config = get_memory_config()
|
||||
|
||||
# Cancel existing timer if any
|
||||
if self._timer is not None:
|
||||
self._timer.cancel()
|
||||
|
||||
# Start new timer
|
||||
self._timer = threading.Timer(
|
||||
config.debounce_seconds,
|
||||
self._process_queue,
|
||||
)
|
||||
self._timer.daemon = True
|
||||
self._timer.start()
|
||||
|
||||
print(f"Memory update timer set for {config.debounce_seconds}s")
|
||||
|
||||
def _process_queue(self) -> None:
|
||||
"""Process all queued conversation contexts."""
|
||||
# Import here to avoid circular dependency
|
||||
from deerflow.agents.memory.updater import MemoryUpdater
|
||||
|
||||
with self._lock:
|
||||
if self._processing:
|
||||
# Already processing, reschedule
|
||||
self._reset_timer()
|
||||
return
|
||||
|
||||
if not self._queue:
|
||||
return
|
||||
|
||||
self._processing = True
|
||||
contexts_to_process = self._queue.copy()
|
||||
self._queue.clear()
|
||||
self._timer = None
|
||||
|
||||
print(f"Processing {len(contexts_to_process)} queued memory updates")
|
||||
|
||||
try:
|
||||
updater = MemoryUpdater()
|
||||
|
||||
for context in contexts_to_process:
|
||||
try:
|
||||
print(f"Updating memory for thread {context.thread_id}")
|
||||
success = updater.update_memory(
|
||||
messages=context.messages,
|
||||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
)
|
||||
if success:
|
||||
print(f"Memory updated successfully for thread {context.thread_id}")
|
||||
else:
|
||||
print(f"Memory update skipped/failed for thread {context.thread_id}")
|
||||
except Exception as e:
|
||||
print(f"Error updating memory for thread {context.thread_id}: {e}")
|
||||
|
||||
# Small delay between updates to avoid rate limiting
|
||||
if len(contexts_to_process) > 1:
|
||||
time.sleep(0.5)
|
||||
|
||||
finally:
|
||||
with self._lock:
|
||||
self._processing = False
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Force immediate processing of the queue.
|
||||
|
||||
This is useful for testing or graceful shutdown.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._timer is not None:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
|
||||
self._process_queue()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the queue without processing.
|
||||
|
||||
This is useful for testing.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._timer is not None:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
self._queue.clear()
|
||||
self._processing = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
"""Get the number of pending updates."""
|
||||
with self._lock:
|
||||
return len(self._queue)
|
||||
|
||||
@property
|
||||
def is_processing(self) -> bool:
|
||||
"""Check if the queue is currently being processed."""
|
||||
with self._lock:
|
||||
return self._processing
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_memory_queue: MemoryUpdateQueue | None = None
|
||||
_queue_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_memory_queue() -> MemoryUpdateQueue:
|
||||
"""Get the global memory update queue singleton.
|
||||
|
||||
Returns:
|
||||
The memory update queue instance.
|
||||
"""
|
||||
global _memory_queue
|
||||
with _queue_lock:
|
||||
if _memory_queue is None:
|
||||
_memory_queue = MemoryUpdateQueue()
|
||||
return _memory_queue
|
||||
|
||||
|
||||
def reset_memory_queue() -> None:
|
||||
"""Reset the global memory queue.
|
||||
|
||||
This is useful for testing.
|
||||
"""
|
||||
global _memory_queue
|
||||
with _queue_lock:
|
||||
if _memory_queue is not None:
|
||||
_memory_queue.clear()
|
||||
_memory_queue = None
|
||||
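The debounce semantics can be exercised directly; a sketch, assuming memory is enabled in config and `msgs_a`/`msgs_b` are placeholder message lists:

    queue = get_memory_queue()
    queue.add(thread_id="t-1", messages=msgs_a)  # arms the debounce timer
    queue.add(thread_id="t-1", messages=msgs_b)  # replaces the pending entry, resets the timer
    assert queue.pending_count == 1              # same-thread entries are deduplicated
    queue.flush()                                # process now instead of waiting out the debounce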
backend/packages/harness/deerflow/agents/memory/updater.py (new file, +384)
@@ -0,0 +1,384 @@
"""Memory updater for reading, writing, and updating memory data."""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from deerflow.agents.memory.prompt import (
|
||||
MEMORY_UPDATE_PROMPT,
|
||||
format_conversation_for_update,
|
||||
)
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
|
||||
def _get_memory_file_path(agent_name: str | None = None) -> Path:
|
||||
"""Get the path to the memory file.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, returns the per-agent memory file path.
|
||||
If None, returns the global memory file path.
|
||||
|
||||
Returns:
|
||||
Path to the memory file.
|
||||
"""
|
||||
if agent_name is not None:
|
||||
return get_paths().agent_memory_file(agent_name)
|
||||
|
||||
config = get_memory_config()
|
||||
if config.storage_path:
|
||||
p = Path(config.storage_path)
|
||||
# Absolute path: use as-is; relative path: resolve against base_dir
|
||||
return p if p.is_absolute() else get_paths().base_dir / p
|
||||
return get_paths().memory_file
|
||||
|
||||
|
||||
def _create_empty_memory() -> dict[str, Any]:
|
||||
"""Create an empty memory structure."""
|
||||
return {
|
||||
"version": "1.0",
|
||||
"lastUpdated": datetime.utcnow().isoformat() + "Z",
|
||||
"user": {
|
||||
"workContext": {"summary": "", "updatedAt": ""},
|
||||
"personalContext": {"summary": "", "updatedAt": ""},
|
||||
"topOfMind": {"summary": "", "updatedAt": ""},
|
||||
},
|
||||
"history": {
|
||||
"recentMonths": {"summary": "", "updatedAt": ""},
|
||||
"earlierContext": {"summary": "", "updatedAt": ""},
|
||||
"longTermBackground": {"summary": "", "updatedAt": ""},
|
||||
},
|
||||
"facts": [],
|
||||
}
|
||||
|
||||
|
||||
# Per-agent memory cache: keyed by agent_name (None = global)
|
||||
# Value: (memory_data, file_mtime)
|
||||
_memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
||||
|
||||
|
||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Get the current memory data (cached with file modification time check).
|
||||
|
||||
The cache is automatically invalidated if the memory file has been modified
|
||||
since the last load, ensuring fresh data is always returned.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, loads per-agent memory. If None, loads global memory.
|
||||
|
||||
Returns:
|
||||
The memory data dictionary.
|
||||
"""
|
||||
file_path = _get_memory_file_path(agent_name)
|
||||
|
||||
# Get current file modification time
|
||||
try:
|
||||
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
current_mtime = None
|
||||
|
||||
cached = _memory_cache.get(agent_name)
|
||||
|
||||
# Invalidate cache if file has been modified or doesn't exist
|
||||
if cached is None or cached[1] != current_mtime:
|
||||
memory_data = _load_memory_from_file(agent_name)
|
||||
_memory_cache[agent_name] = (memory_data, current_mtime)
|
||||
return memory_data
|
||||
|
||||
return cached[0]
|
||||
|
||||
|
||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data from file, forcing cache invalidation.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, reloads per-agent memory. If None, reloads global memory.
|
||||
|
||||
Returns:
|
||||
The reloaded memory data dictionary.
|
||||
"""
|
||||
file_path = _get_memory_file_path(agent_name)
|
||||
memory_data = _load_memory_from_file(agent_name)
|
||||
|
||||
try:
|
||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
_memory_cache[agent_name] = (memory_data, mtime)
|
||||
return memory_data
|
||||
|
||||
|
||||
def _load_memory_from_file(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data from file.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, loads per-agent memory file. If None, loads global.
|
||||
|
||||
Returns:
|
||||
The memory data dictionary.
|
||||
"""
|
||||
file_path = _get_memory_file_path(agent_name)
|
||||
|
||||
if not file_path.exists():
|
||||
return _create_empty_memory()
|
||||
|
||||
try:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"Failed to load memory file: {e}")
|
||||
return _create_empty_memory()
|
||||
|
||||
|
||||
# Matches sentences that describe a file-upload *event* rather than general
|
||||
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
||||
# such as "User works with CSV files" or "prefers PDF export".
|
||||
_UPLOAD_SENTENCE_RE = re.compile(
|
||||
r"[^.!?]*\b(?:"
|
||||
r"upload(?:ed|ing)?(?:\s+\w+){0,3}\s+(?:file|files?|document|documents?|attachment|attachments?)"
|
||||
r"|file\s+upload"
|
||||
r"|/mnt/user-data/uploads/"
|
||||
r"|<uploaded_files>"
|
||||
r")[^.!?]*[.!?]?\s*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _strip_upload_mentions_from_memory(memory_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove sentences about file uploads from all memory summaries and facts.
|
||||
|
||||
Uploaded files are session-scoped; persisting upload events in long-term
|
||||
memory causes the agent to search for non-existent files in future sessions.
|
||||
"""
|
||||
# Scrub summaries in user/history sections
|
||||
for section in ("user", "history"):
|
||||
section_data = memory_data.get(section, {})
|
||||
for _key, val in section_data.items():
|
||||
if isinstance(val, dict) and "summary" in val:
|
||||
cleaned = _UPLOAD_SENTENCE_RE.sub("", val["summary"]).strip()
|
||||
cleaned = re.sub(r" +", " ", cleaned)
|
||||
val["summary"] = cleaned
|
||||
|
||||
# Also remove any facts that describe upload events
|
||||
facts = memory_data.get("facts", [])
|
||||
if facts:
|
||||
memory_data["facts"] = [f for f in facts if not _UPLOAD_SENTENCE_RE.search(f.get("content", ""))]
|
||||
|
||||
return memory_data
|
||||
|
||||
|
||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Save memory data to file and update cache.
|
||||
|
||||
Args:
|
||||
memory_data: The memory data to save.
|
||||
agent_name: If provided, saves to per-agent memory file. If None, saves to global.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
file_path = _get_memory_file_path(agent_name)
|
||||
|
||||
try:
|
||||
# Ensure directory exists
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Update lastUpdated timestamp
|
||||
memory_data["lastUpdated"] = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
# Write atomically using temp file
|
||||
temp_path = file_path.with_suffix(".tmp")
|
||||
with open(temp_path, "w", encoding="utf-8") as f:
|
||||
json.dump(memory_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Rename temp file to actual file (atomic on most systems)
|
||||
temp_path.replace(file_path)
|
||||
|
||||
# Update cache and file modification time
|
||||
try:
|
||||
mtime = file_path.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
_memory_cache[agent_name] = (memory_data, mtime)
|
||||
|
||||
print(f"Memory saved to {file_path}")
|
||||
return True
|
||||
except OSError as e:
|
||||
print(f"Failed to save memory file: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class MemoryUpdater:
|
||||
"""Updates memory using LLM based on conversation context."""
|
||||
|
||||
def __init__(self, model_name: str | None = None):
|
||||
"""Initialize the memory updater.
|
||||
|
||||
Args:
|
||||
model_name: Optional model name to use. If None, uses config or default.
|
||||
"""
|
||||
self._model_name = model_name
|
||||
|
||||
def _get_model(self):
|
||||
"""Get the model for memory updates."""
|
||||
config = get_memory_config()
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return False
|
||||
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get current memory
|
||||
current_memory = get_memory_data(agent_name)
|
||||
|
||||
# Format conversation for prompt
|
||||
conversation_text = format_conversation_for_update(messages)
|
||||
|
||||
if not conversation_text.strip():
|
||||
return False
|
||||
|
||||
# Build prompt
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
conversation=conversation_text,
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
model = self._get_model()
|
||||
response = model.invoke(prompt)
|
||||
response_text = str(response.content).strip()
|
||||
|
||||
# Parse response
|
||||
# Remove markdown code blocks if present
|
||||
if response_text.startswith("```"):
|
||||
lines = response_text.split("\n")
|
||||
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||
|
||||
update_data = json.loads(response_text)
|
||||
|
||||
# Apply updates
|
||||
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
|
||||
|
||||
# Strip file-upload mentions from all summaries before saving.
|
||||
# Uploaded files are session-scoped and won't exist in future sessions,
|
||||
# so recording upload events in long-term memory causes the agent to
|
||||
# try (and fail) to locate those files in subsequent conversations.
|
||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||
|
||||
# Save
|
||||
return _save_memory_to_file(updated_memory, agent_name)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Failed to parse LLM response for memory update: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Memory update failed: {e}")
|
||||
return False
|
||||
|
||||
def _apply_updates(
|
||||
self,
|
||||
current_memory: dict[str, Any],
|
||||
update_data: dict[str, Any],
|
||||
thread_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Apply LLM-generated updates to memory.
|
||||
|
||||
Args:
|
||||
current_memory: Current memory data.
|
||||
update_data: Updates from LLM.
|
||||
thread_id: Optional thread ID for tracking.
|
||||
|
||||
Returns:
|
||||
Updated memory data.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
now = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
# Update user sections
|
||||
user_updates = update_data.get("user", {})
|
||||
for section in ["workContext", "personalContext", "topOfMind"]:
|
||||
section_data = user_updates.get(section, {})
|
||||
if section_data.get("shouldUpdate") and section_data.get("summary"):
|
||||
current_memory["user"][section] = {
|
||||
"summary": section_data["summary"],
|
||||
"updatedAt": now,
|
||||
}
|
||||
|
||||
# Update history sections
|
||||
history_updates = update_data.get("history", {})
|
||||
for section in ["recentMonths", "earlierContext", "longTermBackground"]:
|
||||
section_data = history_updates.get(section, {})
|
||||
if section_data.get("shouldUpdate") and section_data.get("summary"):
|
||||
current_memory["history"][section] = {
|
||||
"summary": section_data["summary"],
|
||||
"updatedAt": now,
|
||||
}
|
||||
|
||||
# Remove facts
|
||||
facts_to_remove = set(update_data.get("factsToRemove", []))
|
||||
if facts_to_remove:
|
||||
current_memory["facts"] = [f for f in current_memory.get("facts", []) if f.get("id") not in facts_to_remove]
|
||||
|
||||
# Add new facts
|
||||
new_facts = update_data.get("newFacts", [])
|
||||
for fact in new_facts:
|
||||
confidence = fact.get("confidence", 0.5)
|
||||
if confidence >= config.fact_confidence_threshold:
|
||||
fact_entry = {
|
||||
"id": f"fact_{uuid.uuid4().hex[:8]}",
|
||||
"content": fact.get("content", ""),
|
||||
"category": fact.get("category", "context"),
|
||||
"confidence": confidence,
|
||||
"createdAt": now,
|
||||
"source": thread_id or "unknown",
|
||||
}
|
||||
current_memory["facts"].append(fact_entry)
|
||||
|
||||
# Enforce max facts limit
|
||||
if len(current_memory["facts"]) > config.max_facts:
|
||||
# Sort by confidence and keep top ones
|
||||
current_memory["facts"] = sorted(
|
||||
current_memory["facts"],
|
||||
key=lambda f: f.get("confidence", 0),
|
||||
reverse=True,
|
||||
)[: config.max_facts]
|
||||
|
||||
return current_memory
|
||||
|
||||
|
||||
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name)
|
||||
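End to end, the updater is driven like this minimal sketch; it needs a configured chat model plus memory.enabled, and the LLM call makes the outcome non-deterministic (the conversation content is invented):

    from langchain_core.messages import AIMessage, HumanMessage

    messages = [
        HumanMessage(content="I maintain DeerFlow; it just passed 16k stars."),
        AIMessage(content="Congrats, noted."),
    ]
    ok = update_memory_from_conversation(messages, thread_id="demo-thread")
    print("saved" if ok else "skipped/failed")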
@@ -0,0 +1,173 @@
"""Middleware for intercepting clarification requests and presenting them to the user."""

from collections.abc import Callable
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.graph import END
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command


class ClarificationMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    pass


class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
    """Intercepts clarification tool calls and interrupts execution to present questions to the user.

    When the model calls the `ask_clarification` tool, this middleware:
    1. Intercepts the tool call before execution
    2. Extracts the clarification question and metadata
    3. Formats a user-friendly message
    4. Returns a Command that interrupts execution and presents the question
    5. Waits for user response before continuing

    This replaces the tool-based approach where clarification continued the conversation flow.
    """

    state_schema = ClarificationMiddlewareState

    def _is_chinese(self, text: str) -> bool:
        """Check if text contains Chinese characters.

        Args:
            text: Text to check

        Returns:
            True if text contains Chinese characters
        """
        return any("\u4e00" <= char <= "\u9fff" for char in text)

    def _format_clarification_message(self, args: dict) -> str:
        """Format the clarification arguments into a user-friendly message.

        Args:
            args: The tool call arguments containing clarification details

        Returns:
            Formatted message string
        """
        question = args.get("question", "")
        clarification_type = args.get("clarification_type", "missing_info")
        context = args.get("context")
        options = args.get("options", [])

        # Type-specific icons
        type_icons = {
            "missing_info": "❓",
            "ambiguous_requirement": "🤔",
            "approach_choice": "🔀",
            "risk_confirmation": "⚠️",
            "suggestion": "💡",
        }

        icon = type_icons.get(clarification_type, "❓")

        # Build the message naturally
        message_parts = []

        # Add icon and question together for a more natural flow
        if context:
            # If there's context, present it first as background
            message_parts.append(f"{icon} {context}")
            message_parts.append(f"\n{question}")
        else:
            # Just the question with icon
            message_parts.append(f"{icon} {question}")

        # Add options in a cleaner format
        if options and len(options) > 0:
            message_parts.append("")  # blank line for spacing
            for i, option in enumerate(options, 1):
                message_parts.append(f"  {i}. {option}")

        return "\n".join(message_parts)

    def _handle_clarification(self, request: ToolCallRequest) -> Command:
        """Handle clarification request and return command to interrupt execution.

        Args:
            request: Tool call request

        Returns:
            Command that interrupts execution with the formatted clarification message
        """
        # Extract clarification arguments
        args = request.tool_call.get("args", {})
        question = args.get("question", "")

        print("[ClarificationMiddleware] Intercepted clarification request")
        print(f"[ClarificationMiddleware] Question: {question}")

        # Format the clarification message
        formatted_message = self._format_clarification_message(args)

        # Get the tool call ID
        tool_call_id = request.tool_call.get("id", "")

        # Create a ToolMessage with the formatted question
        # This will be added to the message history
        tool_message = ToolMessage(
            content=formatted_message,
            tool_call_id=tool_call_id,
            name="ask_clarification",
        )

        # Return a Command that:
        # 1. Adds the formatted tool message
        # 2. Interrupts execution by going to __end__
        # Note: We don't add an extra AIMessage here - the frontend will detect
        # and display ask_clarification tool messages directly
        return Command(
            update={"messages": [tool_message]},
            goto=END,
        )

    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Intercept ask_clarification tool calls and interrupt execution (sync version).

        Args:
            request: Tool call request
            handler: Original tool execution handler

        Returns:
            Command that interrupts execution with the formatted clarification message
        """
        # Check if this is an ask_clarification tool call
        if request.tool_call.get("name") != "ask_clarification":
            # Not a clarification call, execute normally
            return handler(request)

        return self._handle_clarification(request)

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Intercept ask_clarification tool calls and interrupt execution (async version).

        Args:
            request: Tool call request
            handler: Original tool execution handler (async)

        Returns:
            Command that interrupts execution with the formatted clarification message
        """
        # Check if this is an ask_clarification tool call
        if request.tool_call.get("name") != "ask_clarification":
            # Not a clarification call, execute normally
            return await handler(request)

        return self._handle_clarification(request)
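For context, a middleware like this is attached when the agent is built. A sketch, assuming langchain's create_agent accepts a middleware list (consistent with the AgentMiddleware import above); the model id and the ask_clarification tool object are illustrative:

    from langchain.agents import create_agent

    agent = create_agent(
        model="openai:gpt-4o",                 # illustrative model id
        tools=[ask_clarification],             # the ask_clarification tool, defined elsewhere
        middleware=[ClarificationMiddleware()],
    )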
@@ -0,0 +1,110 @@
"""Middleware to fix dangling tool calls in message history.

A dangling tool call occurs when an AIMessage contains tool_calls but there are
no corresponding ToolMessages in the history (e.g., due to user interruption or
request cancellation). This causes LLM errors due to the incomplete message format.

This middleware intercepts the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator immediately after the
AIMessage that made the tool calls, ensuring correct message ordering.

Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
at the correct positions (immediately after each dangling AIMessage), not appended
to the end of the message list as before_model + the add_messages reducer would do.
"""

import logging
from collections.abc import Awaitable, Callable
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage

logger = logging.getLogger(__name__)


class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
    """Inserts placeholder ToolMessages for dangling tool calls before model invocation.

    Scans the message history for AIMessages whose tool_calls lack corresponding
    ToolMessages, and injects synthetic error responses immediately after the
    offending AIMessage so the LLM receives a well-formed conversation.
    """

    def _build_patched_messages(self, messages: list) -> list | None:
        """Return a new message list with patches inserted at the correct positions.

        For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
        a synthetic ToolMessage is inserted immediately after that AIMessage.
        Returns None if no patches are needed.
        """
        # Collect IDs of all existing ToolMessages
        existing_tool_msg_ids: set[str] = set()
        for msg in messages:
            if isinstance(msg, ToolMessage):
                existing_tool_msg_ids.add(msg.tool_call_id)

        # Check if any patching is needed
        needs_patch = False
        for msg in messages:
            if getattr(msg, "type", None) != "ai":
                continue
            for tc in getattr(msg, "tool_calls", None) or []:
                tc_id = tc.get("id")
                if tc_id and tc_id not in existing_tool_msg_ids:
                    needs_patch = True
                    break
            if needs_patch:
                break

        if not needs_patch:
            return None

        # Build new list with patches inserted right after each dangling AIMessage
        patched: list = []
        patched_ids: set[str] = set()
        patch_count = 0
        for msg in messages:
            patched.append(msg)
            if getattr(msg, "type", None) != "ai":
                continue
            for tc in getattr(msg, "tool_calls", None) or []:
                tc_id = tc.get("id")
                if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
                    patched.append(
                        ToolMessage(
                            content="[Tool call was interrupted and did not return a result.]",
                            tool_call_id=tc_id,
                            name=tc.get("name", "unknown"),
                            status="error",
                        )
                    )
                    patched_ids.add(tc_id)
                    patch_count += 1

        logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
        return patched

    @override
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        patched = self._build_patched_messages(request.messages)
        if patched is not None:
            request = request.override(messages=patched)
        return handler(request)

    @override
    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        patched = self._build_patched_messages(request.messages)
        if patched is not None:
            request = request.override(messages=patched)
        return await handler(request)
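A sketch of the patching behavior on an artificially truncated history (the tool name and call id are invented):

    from langchain_core.messages import AIMessage

    mw = DanglingToolCallMiddleware()
    history = [
        AIMessage(content="", tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}]),
    ]
    patched = mw._build_patched_messages(history)
    assert patched is not None
    assert patched[-1].tool_call_id == "call_1" and patched[-1].status == "error"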
@@ -0,0 +1,227 @@
"""Middleware to detect and break repetitive tool call loops.

P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.

Detection strategy:
1. After each model response, hash the tool calls (name + args).
2. Track recent hashes in a sliding window.
3. If the same hash appears >= warn_threshold times, inject a
   "you are repeating yourself — wrap up" system message (once per hash).
4. If it appears >= hard_limit times, strip all tool_calls from the
   response so the agent is forced to produce a final text answer.
"""

import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import SystemMessage
from langgraph.runtime import Runtime

logger = logging.getLogger(__name__)

# Defaults — can be overridden via constructor
_DEFAULT_WARN_THRESHOLD = 3  # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5  # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20  # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100  # LRU eviction limit


def _hash_tool_calls(tool_calls: list[dict]) -> str:
    """Deterministic hash of a set of tool calls (name + args).

    This is intended to be order-independent: the same multiset of tool calls
    should always produce the same hash, regardless of their input order.
    """
    # First normalize each tool call to a minimal (name, args) structure.
    normalized: list[dict] = []
    for tc in tool_calls:
        normalized.append(
            {
                "name": tc.get("name", ""),
                "args": tc.get("args", {}),
            }
        )

    # Sort by both name and a deterministic serialization of args so that
    # permutations of the same multiset of calls yield the same ordering.
    normalized.sort(
        key=lambda tc: (
            tc["name"],
            json.dumps(tc["args"], sort_keys=True, default=str),
        )
    )
    blob = json.dumps(normalized, sort_keys=True, default=str)
    return hashlib.md5(blob.encode()).hexdigest()[:12]


_WARNING_MSG = (
    "[LOOP DETECTED] You are repeating the same tool calls. "
    "Stop calling tools and produce your final answer now. "
    "If you cannot complete the task, summarize what you accomplished so far."
)

_HARD_STOP_MSG = (
    "[FORCED STOP] Repeated tool calls exceeded the safety limit. "
    "Producing final answer with results collected so far."
)


class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
    """Detects and breaks repetitive tool call loops.

    Args:
        warn_threshold: Number of identical tool call sets before injecting
            a warning message. Default: 3.
        hard_limit: Number of identical tool call sets before stripping
            tool_calls entirely. Default: 5.
        window_size: Size of the sliding window for tracking calls.
            Default: 20.
        max_tracked_threads: Maximum number of threads to track before
            evicting the least recently used. Default: 100.
    """

    def __init__(
        self,
        warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
        hard_limit: int = _DEFAULT_HARD_LIMIT,
        window_size: int = _DEFAULT_WINDOW_SIZE,
        max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
    ):
        super().__init__()
        self.warn_threshold = warn_threshold
        self.hard_limit = hard_limit
        self.window_size = window_size
        self.max_tracked_threads = max_tracked_threads
        self._lock = threading.Lock()
        # Per-thread tracking using OrderedDict for LRU eviction
        self._history: OrderedDict[str, list[str]] = OrderedDict()
        self._warned: dict[str, set[str]] = defaultdict(set)

    def _get_thread_id(self, runtime: Runtime) -> str:
        """Extract thread_id from runtime context for per-thread tracking."""
        thread_id = runtime.context.get("thread_id")
        if thread_id:
            return thread_id
        return "default"

    def _evict_if_needed(self) -> None:
        """Evict least recently used threads if over the limit.

        Must be called while holding self._lock.
        """
        while len(self._history) > self.max_tracked_threads:
            evicted_id, _ = self._history.popitem(last=False)
            self._warned.pop(evicted_id, None)
            logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)

    def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
        """Track tool calls and check for loops.

        Returns:
            (warning_message_or_none, should_hard_stop)
        """
        messages = state.get("messages", [])
        if not messages:
            return None, False

        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None, False

        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None, False

        thread_id = self._get_thread_id(runtime)
        call_hash = _hash_tool_calls(tool_calls)

        with self._lock:
            # Touch / create entry (move to end for LRU)
            if thread_id in self._history:
                self._history.move_to_end(thread_id)
            else:
                self._history[thread_id] = []
                self._evict_if_needed()

            history = self._history[thread_id]
            history.append(call_hash)
            if len(history) > self.window_size:
                history[:] = history[-self.window_size:]

            count = history.count(call_hash)
            tool_names = [tc.get("name", "?") for tc in tool_calls]

            if count >= self.hard_limit:
                logger.error(
                    "Loop hard limit reached — forcing stop",
                    extra={
                        "thread_id": thread_id,
                        "call_hash": call_hash,
                        "count": count,
                        "tools": tool_names,
                    },
                )
                return _HARD_STOP_MSG, True

            if count >= self.warn_threshold:
                warned = self._warned[thread_id]
                if call_hash not in warned:
                    warned.add(call_hash)
                    logger.warning(
                        "Repetitive tool calls detected — injecting warning",
                        extra={
                            "thread_id": thread_id,
                            "call_hash": call_hash,
                            "count": count,
                            "tools": tool_names,
                        },
                    )
                    return _WARNING_MSG, False
                # Warning already injected for this hash — suppress
                return None, False

            return None, False

    def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
        warning, hard_stop = self._track_and_check(state, runtime)

        if hard_stop:
            # Strip tool_calls from the last AIMessage to force text output
            messages = state.get("messages", [])
            last_msg = messages[-1]
            stripped_msg = last_msg.model_copy(update={
                "tool_calls": [],
                "content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
            })
            return {"messages": [stripped_msg]}

        if warning:
            # Inject a system message warning the model
            return {"messages": [SystemMessage(content=warning)]}

        return None

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._apply(state, runtime)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._apply(state, runtime)

    def reset(self, thread_id: str | None = None) -> None:
        """Clear tracking state. If thread_id given, clear only that thread."""
        with self._lock:
            if thread_id:
                self._history.pop(thread_id, None)
                self._warned.pop(thread_id, None)
            else:
                self._history.clear()
                self._warned.clear()
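The order-independence of the hash is easy to check; a sketch with invented calls (note that tool-call ids are deliberately excluded from the hash, since only name and args are normalized):

    calls = [
        {"name": "web_search", "args": {"query": "deerflow"}, "id": "a"},
        {"name": "read_file", "args": {"path": "README.md"}, "id": "b"},
    ]
    assert _hash_tool_calls(calls) == _hash_tool_calls(list(reversed(calls)))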
@@ -0,0 +1,149 @@
|
||||
"""Middleware for memory mechanism."""
|
||||
|
||||
import re
|
||||
from typing import Any, override
|
||||
|
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config


class MemoryMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    pass


def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
    """Filter messages to keep only user inputs and final assistant responses.

    This filters out:
    - Tool messages (intermediate tool call results)
    - AI messages with tool_calls (intermediate steps, not final responses)
    - The <uploaded_files> block injected by UploadsMiddleware into human messages
      (file paths are session-scoped and must not persist in long-term memory).
      The user's actual question is preserved; only turns whose content is entirely
      the upload block (nothing remains after stripping) are dropped along with
      their paired assistant response.

    Only keeps:
    - Human messages (with the ephemeral upload block removed)
    - AI messages without tool_calls (final assistant responses), unless the
      paired human turn was upload-only and had no real user text.

    Args:
        messages: List of all conversation messages.

    Returns:
        Filtered list containing only user inputs and final assistant responses.
    """
    _UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)

    filtered = []
    skip_next_ai = False
    for msg in messages:
        msg_type = getattr(msg, "type", None)

        if msg_type == "human":
            content = getattr(msg, "content", "")
            if isinstance(content, list):
                content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
            content_str = str(content)
            if "<uploaded_files>" in content_str:
                # Strip the ephemeral upload block; keep the user's real question.
                stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
                if not stripped:
                    # Nothing left — the entire turn was upload bookkeeping;
                    # skip it and the paired assistant response.
                    skip_next_ai = True
                    continue
                # Rebuild the message with cleaned content so the user's question
                # is still available for memory summarization.
                from copy import copy

                clean_msg = copy(msg)
                clean_msg.content = stripped
                filtered.append(clean_msg)
                skip_next_ai = False
            else:
                filtered.append(msg)
                skip_next_ai = False
        elif msg_type == "ai":
            tool_calls = getattr(msg, "tool_calls", None)
            if not tool_calls:
                if skip_next_ai:
                    skip_next_ai = False
                    continue
                filtered.append(msg)
        # Skip tool messages and AI messages with tool_calls

    return filtered


class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
    """Middleware that queues conversation for memory update after agent execution.

    This middleware:
    1. After each agent execution, queues the conversation for memory update
    2. Only includes user inputs and final assistant responses (ignores tool calls)
    3. The queue uses debouncing to batch multiple updates together
    4. Memory is updated asynchronously via LLM summarization
    """

    state_schema = MemoryMiddlewareState

    def __init__(self, agent_name: str | None = None):
        """Initialize the MemoryMiddleware.

        Args:
            agent_name: If provided, memory is stored per-agent. If None, uses global memory.
        """
        super().__init__()
        self._agent_name = agent_name

    @override
    def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
        """Queue conversation for memory update after agent completes.

        Args:
            state: The current agent state.
            runtime: The runtime context.

        Returns:
            None (no state changes needed from this middleware).
        """
        config = get_memory_config()
        if not config.enabled:
            return None

        # Get thread ID from runtime context
        thread_id = runtime.context.get("thread_id")
        if not thread_id:
            print("MemoryMiddleware: No thread_id in context, skipping memory update")
            return None

        # Get messages from state
        messages = state.get("messages", [])
        if not messages:
            print("MemoryMiddleware: No messages in state, skipping memory update")
            return None

        # Filter to only keep user inputs and final assistant responses
        filtered_messages = _filter_messages_for_memory(messages)

        # Only queue if there's meaningful conversation
        # At minimum need one user message and one assistant response
        user_messages = [m for m in filtered_messages if getattr(m, "type", None) == "human"]
        assistant_messages = [m for m in filtered_messages if getattr(m, "type", None) == "ai"]

        if not user_messages or not assistant_messages:
            return None

        # Queue the filtered conversation for memory update
        queue = get_memory_queue()
        queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)

        return None
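For illustration (not part of the diff), a minimal sketch of what _filter_messages_for_memory keeps and drops, assuming the function is importable from this module:

    from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

    history = [
        HumanMessage(content="<uploaded_files>\n- report.pdf\n</uploaded_files>\nSummarize the report."),
        AIMessage(content="", tool_calls=[{"name": "read_file", "args": {}, "id": "c1"}]),
        ToolMessage(content="(file contents)", tool_call_id="c1"),
        AIMessage(content="The report argues that..."),
    ]
    kept = _filter_messages_for_memory(history)
    # kept has two messages: the human turn with its <uploaded_files> block
    # stripped ("Summarize the report.") and the final AI answer. The
    # tool-calling AI message and the ToolMessage are both dropped.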
@@ -0,0 +1,75 @@
"""Middleware to enforce maximum concurrent subagent tool calls per model response."""

import logging
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS

logger = logging.getLogger(__name__)

# Valid range for max_concurrent_subagents
MIN_SUBAGENT_LIMIT = 2
MAX_SUBAGENT_LIMIT = 4


def _clamp_subagent_limit(value: int) -> int:
    """Clamp subagent limit to valid range [2, 4]."""
    return max(MIN_SUBAGENT_LIMIT, min(MAX_SUBAGENT_LIMIT, value))


class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
    """Truncates excess 'task' tool calls from a single model response.

    When an LLM generates more than max_concurrent parallel task tool calls
    in one response, this middleware keeps only the first max_concurrent and
    discards the rest. This is more reliable than prompt-based limits.

    Args:
        max_concurrent: Maximum number of concurrent subagent calls allowed.
            Defaults to MAX_CONCURRENT_SUBAGENTS (3). Clamped to [2, 4].
    """

    def __init__(self, max_concurrent: int = MAX_CONCURRENT_SUBAGENTS):
        super().__init__()
        self.max_concurrent = _clamp_subagent_limit(max_concurrent)

    def _truncate_task_calls(self, state: AgentState) -> dict | None:
        messages = state.get("messages", [])
        if not messages:
            return None

        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None

        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None

        # Count task tool calls
        task_indices = [i for i, tc in enumerate(tool_calls) if tc.get("name") == "task"]
        if len(task_indices) <= self.max_concurrent:
            return None

        # Build set of indices to drop (excess task calls beyond the limit)
        indices_to_drop = set(task_indices[self.max_concurrent :])
        truncated_tool_calls = [tc for i, tc in enumerate(tool_calls) if i not in indices_to_drop]

        dropped_count = len(indices_to_drop)
        logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")

        # Replace the AIMessage with truncated tool_calls (same id triggers replacement)
        updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
        return {"messages": [updated_msg]}

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._truncate_task_calls(state)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._truncate_task_calls(state)
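A quick sketch (not from the diff) of the truncation behavior; it calls the private helper directly for brevity:

    from langchain_core.messages import AIMessage

    mw = SubagentLimitMiddleware(max_concurrent=2)
    msg = AIMessage(
        content="",
        tool_calls=[{"name": "task", "args": {"n": i}, "id": f"t{i}"} for i in range(4)],
    )
    update = mw._truncate_task_calls({"messages": [msg]})
    # update["messages"][0].tool_calls now has 2 entries (t0, t1), and a
    # warning is logged for the 2 dropped calls.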
@@ -0,0 +1,90 @@
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import Paths, get_paths


class ThreadDataMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    thread_data: NotRequired[ThreadDataState | None]


class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
    """Create thread data directories for each thread execution.

    Creates the following directory structure:
    - {base_dir}/threads/{thread_id}/user-data/workspace
    - {base_dir}/threads/{thread_id}/user-data/uploads
    - {base_dir}/threads/{thread_id}/user-data/outputs

    Lifecycle Management:
    - With lazy_init=True (default): Only compute paths, directories created on-demand
    - With lazy_init=False: Eagerly create directories in before_agent()
    """

    state_schema = ThreadDataMiddlewareState

    def __init__(self, base_dir: str | None = None, lazy_init: bool = True):
        """Initialize the middleware.

        Args:
            base_dir: Base directory for thread data. Defaults to Paths resolution.
            lazy_init: If True, defer directory creation until needed.
                If False, create directories eagerly in before_agent().
                Default is True for optimal performance.
        """
        super().__init__()
        self._paths = Paths(base_dir) if base_dir else get_paths()
        self._lazy_init = lazy_init

    def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
        """Get the paths for a thread's data directories.

        Args:
            thread_id: The thread ID.

        Returns:
            Dictionary with workspace_path, uploads_path, and outputs_path.
        """
        return {
            "workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
            "uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
            "outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
        }

    def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
        """Create the thread data directories.

        Args:
            thread_id: The thread ID.

        Returns:
            Dictionary with the created directory paths.
        """
        self._paths.ensure_thread_dirs(thread_id)
        return self._get_thread_paths(thread_id)

    @override
    def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
        thread_id = runtime.context.get("thread_id")
        if thread_id is None:
            raise ValueError("Thread ID is required in the context")

        if self._lazy_init:
            # Lazy initialization: only compute paths, don't create directories
            paths = self._get_thread_paths(thread_id)
        else:
            # Eager initialization: create directories immediately
            paths = self._create_thread_directories(thread_id)
            print(f"Created thread data directories for thread {thread_id}")

        return {
            "thread_data": {
                **paths,
            }
        }
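Illustrative only: exercising before_agent with a stand-in runtime object (langgraph's Runtime is only used for its .context mapping here), and an assumed writable base_dir:

    from types import SimpleNamespace

    mw = ThreadDataMiddleware(base_dir="/tmp/deerflow-demo", lazy_init=True)
    fake_runtime = SimpleNamespace(context={"thread_id": "demo-thread"})
    update = mw.before_agent({"messages": []}, fake_runtime)
    # update["thread_data"] maps workspace_path / uploads_path / outputs_path
    # to paths under /tmp/deerflow-demo; with lazy_init=True nothing is
    # created on disk yet.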
@@ -0,0 +1,93 @@
"""Middleware for automatic thread title generation."""

from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model


class TitleMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    title: NotRequired[str | None]


class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
    """Automatically generate a title for the thread after the first user message."""

    state_schema = TitleMiddlewareState

    def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
        """Check if we should generate a title for this thread."""
        config = get_title_config()
        if not config.enabled:
            return False

        # Check if thread already has a title in state
        if state.get("title"):
            return False

        # Check if this is the first turn (has at least one user message and one assistant response)
        messages = state.get("messages", [])
        if len(messages) < 2:
            return False

        # Count user and assistant messages
        user_messages = [m for m in messages if m.type == "human"]
        assistant_messages = [m for m in messages if m.type == "ai"]

        # Generate title after first complete exchange
        return len(user_messages) == 1 and len(assistant_messages) >= 1

    async def _generate_title(self, state: TitleMiddlewareState) -> str:
        """Generate a concise title based on the conversation."""
        config = get_title_config()
        messages = state.get("messages", [])

        # Get first user message and first assistant response
        user_msg_content = next((m.content for m in messages if m.type == "human"), "")
        assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")

        # Ensure content is string (LangChain messages can have list content)
        user_msg = str(user_msg_content) if user_msg_content else ""
        assistant_msg = str(assistant_msg_content) if assistant_msg_content else ""

        # Use a lightweight model to generate title
        model = create_chat_model(thinking_enabled=False)

        prompt = config.prompt_template.format(
            max_words=config.max_words,
            user_msg=user_msg[:500],
            assistant_msg=assistant_msg[:500],
        )

        try:
            response = await model.ainvoke(prompt)
            # Ensure response content is string
            title_content = str(response.content) if response.content else ""
            title = title_content.strip().strip('"').strip("'")
            # Limit to max characters
            return title[: config.max_chars] if len(title) > config.max_chars else title
        except Exception as e:
            print(f"Failed to generate title: {e}")
            # Fallback: use first part of user message (by character count)
            fallback_chars = min(config.max_chars, 50)  # Use max_chars or 50, whichever is smaller
            if len(user_msg) > fallback_chars:
                return user_msg[:fallback_chars].rstrip() + "..."
            return user_msg if user_msg else "New Conversation"

    @override
    async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
        """Generate and set thread title after the first agent response."""
        if self._should_generate_title(state):
            title = await self._generate_title(state)
            print(f"Generated thread title: {title}")

            # Store title in state (will be persisted by checkpointer if configured)
            return {"title": title}

        return None
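A small sketch (not from the diff) of the gating logic, assuming the title feature is enabled in config:

    from langchain_core.messages import AIMessage, HumanMessage

    mw = TitleMiddleware()
    first_turn = {"messages": [HumanMessage(content="hi"), AIMessage(content="hello")]}
    later_turn = {"messages": [HumanMessage(content="hi"), AIMessage(content="hello"),
                               HumanMessage(content="more")], "title": "Greeting"}
    mw._should_generate_title(first_turn)  # True: exactly one human msg, >= 1 ai msg
    mw._should_generate_title(later_turn)  # False: a title is already set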
@@ -0,0 +1,100 @@
"""Middleware that extends TodoListMiddleware with context-loss detection.

When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list.
"""

from __future__ import annotations

from typing import Any, override

from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime


def _todos_in_messages(messages: list[Any]) -> bool:
    """Return True if any AIMessage in *messages* contains a write_todos tool call."""
    for msg in messages:
        if isinstance(msg, AIMessage) and msg.tool_calls:
            for tc in msg.tool_calls:
                if tc.get("name") == "write_todos":
                    return True
    return False


def _reminder_in_messages(messages: list[Any]) -> bool:
    """Return True if a todo_reminder HumanMessage is already present in *messages*."""
    for msg in messages:
        if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_reminder":
            return True
    return False


def _format_todos(todos: list[Todo]) -> str:
    """Format a list of Todo items into a human-readable string."""
    lines: list[str] = []
    for todo in todos:
        status = todo.get("status", "pending")
        content = todo.get("content", "")
        lines.append(f"- [{status}] {content}")
    return "\n".join(lines)


class TodoMiddleware(TodoListMiddleware):
    """Extends TodoListMiddleware with `write_todos` context-loss detection.

    When the original `write_todos` tool call has been truncated from the message
    history (e.g., after summarization), the model loses awareness of the current
    todo list. This middleware detects that gap in `before_model` / `abefore_model`
    and injects a reminder message so the model can continue tracking progress.
    """

    @override
    def before_model(
        self,
        state: PlanningState,
        runtime: Runtime,  # noqa: ARG002
    ) -> dict[str, Any] | None:
        """Inject a todo-list reminder when write_todos has left the context window."""
        todos: list[Todo] = state.get("todos") or []  # type: ignore[assignment]
        if not todos:
            return None

        messages = state.get("messages") or []
        if _todos_in_messages(messages):
            # write_todos is still visible in context — nothing to do.
            return None

        if _reminder_in_messages(messages):
            # A reminder was already injected and hasn't been truncated yet.
            return None

        # The todo list exists in state but the original write_todos call is gone.
        # Inject a reminder as a HumanMessage so the model stays aware.
        formatted = _format_todos(todos)
        reminder = HumanMessage(
            name="todo_reminder",
            content=(
                "<system_reminder>\n"
                "Your todo list from earlier is no longer visible in the current context window, "
                "but it is still active. Here is the current state:\n\n"
                f"{formatted}\n\n"
                "Continue tracking and updating this todo list as you work. "
                "Call `write_todos` whenever the status of any item changes.\n"
                "</system_reminder>"
            ),
        )
        return {"messages": [reminder]}

    @override
    async def abefore_model(
        self,
        state: PlanningState,
        runtime: Runtime,
    ) -> dict[str, Any] | None:
        """Async version of before_model."""
        return self.before_model(state, runtime)
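An illustrative sketch (not part of the diff): a state whose todos survive but whose write_todos call has been summarized away triggers the reminder:

    from langchain_core.messages import HumanMessage

    state = {
        "todos": [{"content": "Collect sources", "status": "in_progress"},
                  {"content": "Draft report", "status": "pending"}],
        "messages": [HumanMessage(content="(conversation summary)")],
    }
    update = TodoMiddleware().before_model(state, runtime=None)  # runtime is unused
    # update["messages"][0] is a HumanMessage named "todo_reminder" whose
    # <system_reminder> body lists:
    # - [in_progress] Collect sources
    # - [pending] Draft report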
@@ -0,0 +1,112 @@
"""Tool error handling middleware and shared runtime middleware builders."""

import logging
from collections.abc import Awaitable, Callable
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command

logger = logging.getLogger(__name__)

_MISSING_TOOL_CALL_ID = "missing_tool_call_id"


class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
    """Convert tool exceptions into error ToolMessages so the run can continue."""

    def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
        tool_name = str(request.tool_call.get("name") or "unknown_tool")
        tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
        detail = str(exc).strip() or exc.__class__.__name__
        if len(detail) > 500:
            detail = detail[:497] + "..."

        content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=tool_name,
            status="error",
        )

    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        try:
            return handler(request)
        except GraphBubbleUp:
            # Preserve LangGraph control-flow signals (interrupt/pause/resume).
            raise
        except Exception as exc:
            logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
            return self._build_error_message(request, exc)

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        try:
            return await handler(request)
        except GraphBubbleUp:
            # Preserve LangGraph control-flow signals (interrupt/pause/resume).
            raise
        except Exception as exc:
            logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
            return self._build_error_message(request, exc)


def _build_runtime_middlewares(
    *,
    include_uploads: bool,
    include_dangling_tool_call_patch: bool,
    lazy_init: bool = True,
) -> list[AgentMiddleware]:
    """Build shared base middlewares for agent execution."""
    from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
    from deerflow.sandbox.middleware import SandboxMiddleware

    middlewares: list[AgentMiddleware] = [
        ThreadDataMiddleware(lazy_init=lazy_init),
        SandboxMiddleware(lazy_init=lazy_init),
    ]

    if include_uploads:
        from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware

        middlewares.insert(1, UploadsMiddleware())

    if include_dangling_tool_call_patch:
        from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware

        middlewares.append(DanglingToolCallMiddleware())

    middlewares.append(ToolErrorHandlingMiddleware())
    return middlewares


def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
    """Middlewares shared by lead agent runtime before lead-only middlewares."""
    return _build_runtime_middlewares(
        include_uploads=True,
        include_dangling_tool_call_patch=True,
        lazy_init=lazy_init,
    )


def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
    """Middlewares shared by subagent runtime before subagent-only middlewares."""
    return _build_runtime_middlewares(
        include_uploads=False,
        include_dangling_tool_call_patch=False,
        lazy_init=lazy_init,
    )
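A hedged sketch (not from the diff) of the error-to-ToolMessage conversion; the fake request only needs a .tool_call mapping for this code path:

    from types import SimpleNamespace

    mw = ToolErrorHandlingMiddleware()

    def failing_handler(request):
        raise RuntimeError("connection reset")

    fake_request = SimpleNamespace(tool_call={"name": "web_search", "id": "call-1"})
    msg = mw.wrap_tool_call(fake_request, failing_handler)
    # msg is a ToolMessage with status="error" and content
    # "Error: Tool 'web_search' failed with RuntimeError: connection reset. ..."
    # so the agent loop continues instead of crashing. Note also that
    # build_lead_runtime_middlewares() appends this middleware last, giving the
    # lead order: ThreadData, Uploads, Sandbox, DanglingToolCall, ToolErrorHandling.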
@@ -0,0 +1,204 @@
"""Middleware to inject uploaded files information into agent context."""

import logging
from pathlib import Path
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime

from deerflow.config.paths import Paths, get_paths

logger = logging.getLogger(__name__)


class UploadsMiddlewareState(AgentState):
    """State schema for uploads middleware."""

    uploaded_files: NotRequired[list[dict] | None]


class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
    """Middleware to inject uploaded files information into the agent context.

    Reads file metadata from the current message's additional_kwargs.files
    (set by the frontend after upload) and prepends an <uploaded_files> block
    to the last human message so the model knows which files are available.
    """

    state_schema = UploadsMiddlewareState

    def __init__(self, base_dir: str | None = None):
        """Initialize the middleware.

        Args:
            base_dir: Base directory for thread data. Defaults to Paths resolution.
        """
        super().__init__()
        self._paths = Paths(base_dir) if base_dir else get_paths()

    def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
        """Create a formatted message listing uploaded files.

        Args:
            new_files: Files uploaded in the current message.
            historical_files: Files uploaded in previous messages.

        Returns:
            Formatted string inside <uploaded_files> tags.
        """
        lines = ["<uploaded_files>"]

        lines.append("The following files were uploaded in this message:")
        lines.append("")
        if new_files:
            for file in new_files:
                size_kb = file["size"] / 1024
                size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
                lines.append(f"- {file['filename']} ({size_str})")
                lines.append(f" Path: {file['path']}")
            lines.append("")
        else:
            lines.append("(empty)")

        if historical_files:
            lines.append("The following files were uploaded in previous messages and are still available:")
            lines.append("")
            for file in historical_files:
                size_kb = file["size"] / 1024
                size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
                lines.append(f"- {file['filename']} ({size_str})")
                lines.append(f" Path: {file['path']}")
            lines.append("")

        lines.append("You can read these files using the `read_file` tool with the paths shown above.")
        lines.append("</uploaded_files>")

        return "\n".join(lines)

    def _files_from_kwargs(self, message: HumanMessage, uploads_dir: Path | None = None) -> list[dict] | None:
        """Extract file info from message additional_kwargs.files.

        The frontend sends uploaded file metadata in additional_kwargs.files
        after a successful upload. Each entry has: filename, size (bytes),
        path (virtual path), status.

        Args:
            message: The human message to inspect.
            uploads_dir: Physical uploads directory used to verify file existence.
                When provided, entries whose files no longer exist are skipped.

        Returns:
            List of file dicts with virtual paths, or None if the field is absent or empty.
        """
        kwargs_files = (message.additional_kwargs or {}).get("files")
        if not isinstance(kwargs_files, list) or not kwargs_files:
            return None

        files = []
        for f in kwargs_files:
            if not isinstance(f, dict):
                continue
            filename = f.get("filename") or ""
            if not filename or Path(filename).name != filename:
                continue
            if uploads_dir is not None and not (uploads_dir / filename).is_file():
                continue
            files.append(
                {
                    "filename": filename,
                    "size": int(f.get("size") or 0),
                    "path": f"/mnt/user-data/uploads/{filename}",
                    "extension": Path(filename).suffix,
                }
            )
        return files if files else None

    @override
    def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
        """Inject uploaded files information before agent execution.

        New files come from the current message's additional_kwargs.files.
        Historical files are scanned from the thread's uploads directory,
        excluding the new ones.

        Prepends <uploaded_files> context to the last human message content.
        The original additional_kwargs (including files metadata) is preserved
        on the updated message so the frontend can read it from the stream.

        Args:
            state: Current agent state.
            runtime: Runtime context containing thread_id.

        Returns:
            State updates including uploaded files list.
        """
        messages = list(state.get("messages", []))
        if not messages:
            return None

        last_message_index = len(messages) - 1
        last_message = messages[last_message_index]

        if not isinstance(last_message, HumanMessage):
            return None

        # Resolve uploads directory for existence checks
        thread_id = runtime.context.get("thread_id")
        uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None

        # Get newly uploaded files from the current message's additional_kwargs.files
        new_files = self._files_from_kwargs(last_message, uploads_dir) or []

        # Collect historical files from the uploads directory (all except the new ones)
        new_filenames = {f["filename"] for f in new_files}
        historical_files: list[dict] = []
        if uploads_dir and uploads_dir.exists():
            for file_path in sorted(uploads_dir.iterdir()):
                if file_path.is_file() and file_path.name not in new_filenames:
                    stat = file_path.stat()
                    historical_files.append(
                        {
                            "filename": file_path.name,
                            "size": stat.st_size,
                            "path": f"/mnt/user-data/uploads/{file_path.name}",
                            "extension": file_path.suffix,
                        }
                    )

        if not new_files and not historical_files:
            return None

        logger.debug(f"New files: {[f['filename'] for f in new_files]}, historical: {[f['filename'] for f in historical_files]}")

        # Create files message and prepend to the last human message content
        files_message = self._create_files_message(new_files, historical_files)

        # Extract original content - handle both string and list formats
        original_content = ""
        if isinstance(last_message.content, str):
            original_content = last_message.content
        elif isinstance(last_message.content, list):
            text_parts = []
            for block in last_message.content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text_parts.append(block.get("text", ""))
            original_content = "\n".join(text_parts)

        # Create new message with combined content.
        # Preserve additional_kwargs (including files metadata) so the frontend
        # can read structured file info from the streamed message.
        updated_message = HumanMessage(
            content=f"{files_message}\n\n{original_content}",
            id=last_message.id,
            additional_kwargs=last_message.additional_kwargs,
        )

        messages[last_message_index] = updated_message

        return {
            "uploaded_files": new_files,
            "messages": messages,
        }
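For illustration (not part of the diff): feeding a message with frontend-style file metadata through before_agent, using a stand-in runtime with no thread_id so the on-disk existence check is skipped, and an assumed writable base_dir:

    from types import SimpleNamespace
    from langchain_core.messages import HumanMessage

    msg = HumanMessage(
        content="Please summarize the attachment.",
        additional_kwargs={"files": [{"filename": "paper.pdf", "size": 204800,
                                      "path": "/mnt/user-data/uploads/paper.pdf",
                                      "status": "uploaded"}]},
    )
    mw = UploadsMiddleware(base_dir="/tmp/deerflow-demo")
    update = mw.before_agent({"messages": [msg]}, SimpleNamespace(context={}))
    # update["messages"][-1].content now begins with an <uploaded_files> block
    # listing "paper.pdf (200.0 KB)"; update["uploaded_files"] carries the
    # normalized entries with rebuilt virtual paths.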
@@ -0,0 +1,221 @@
"""Middleware for injecting image details into conversation before LLM call."""

from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime

from deerflow.agents.thread_state import ViewedImageData


class ViewImageMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    viewed_images: NotRequired[dict[str, ViewedImageData] | None]


class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
    """Injects image details as a human message before LLM calls when view_image tools have completed.

    This middleware:
    1. Runs before each LLM call
    2. Checks if the last assistant message contains view_image tool calls
    3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages)
    4. If conditions are met, creates a human message with all viewed image details (including base64 data)
    5. Adds the message to state so the LLM can see and analyze the images

    This enables the LLM to automatically receive and analyze images that were loaded via view_image tool,
    without requiring explicit user prompts to describe the images.
    """

    state_schema = ViewImageMiddlewareState

    def _get_last_assistant_message(self, messages: list) -> AIMessage | None:
        """Get the last assistant message from the message list.

        Args:
            messages: List of messages

        Returns:
            Last AIMessage or None if not found
        """
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                return msg
        return None

    def _has_view_image_tool(self, message: AIMessage) -> bool:
        """Check if the assistant message contains view_image tool calls.

        Args:
            message: Assistant message to check

        Returns:
            True if message contains view_image tool calls
        """
        if not hasattr(message, "tool_calls") or not message.tool_calls:
            return False

        return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls)

    def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool:
        """Check if all tool calls in the assistant message have been completed.

        Args:
            messages: List of all messages
            assistant_msg: The assistant message containing tool calls

        Returns:
            True if all tool calls have corresponding ToolMessages
        """
        if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls:
            return False

        # Get all tool call IDs from the assistant message
        tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")}

        # Find the index of the assistant message
        try:
            assistant_idx = messages.index(assistant_msg)
        except ValueError:
            return False

        # Get all ToolMessages after the assistant message
        completed_tool_ids = set()
        for msg in messages[assistant_idx + 1 :]:
            if isinstance(msg, ToolMessage) and msg.tool_call_id:
                completed_tool_ids.add(msg.tool_call_id)

        # Check if all tool calls have been completed
        return tool_call_ids.issubset(completed_tool_ids)

    def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]:
        """Create a formatted message with all viewed image details.

        Args:
            state: Current state containing viewed_images

        Returns:
            List of content blocks (text and images) for the HumanMessage
        """
        viewed_images = state.get("viewed_images", {})
        if not viewed_images:
            return ["No images have been viewed."]

        # Build the message with image information
        content_blocks: list[str | dict] = [{"type": "text", "text": "Here are the images you've viewed:"}]

        for image_path, image_data in viewed_images.items():
            mime_type = image_data.get("mime_type", "unknown")
            base64_data = image_data.get("base64", "")

            # Add text description
            content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"})

            # Add the actual image data so LLM can "see" it
            if base64_data:
                content_blocks.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{base64_data}"},
                    }
                )

        return content_blocks

    def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool:
        """Determine if we should inject an image details message.

        Args:
            state: Current state

        Returns:
            True if we should inject the message
        """
        messages = state.get("messages", [])
        if not messages:
            return False

        # Get the last assistant message
        last_assistant_msg = self._get_last_assistant_message(messages)
        if not last_assistant_msg:
            return False

        # Check if it has view_image tool calls
        if not self._has_view_image_tool(last_assistant_msg):
            return False

        # Check if all tools have been completed
        if not self._all_tools_completed(messages, last_assistant_msg):
            return False

        # Check if we've already added an image details message
        # Look for a human message after the last assistant message that contains image details
        assistant_idx = messages.index(last_assistant_msg)
        for msg in messages[assistant_idx + 1 :]:
            if isinstance(msg, HumanMessage):
                content_str = str(msg.content)
                if "Here are the images you've viewed" in content_str or "Here are the details of the images you've viewed" in content_str:
                    # Already added, don't add again
                    return False

        return True

    def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None:
        """Internal helper to inject image details message.

        Args:
            state: Current state

        Returns:
            State update with additional human message, or None if no update needed
        """
        if not self._should_inject_image_message(state):
            return None

        # Create the image details message with text and image content
        image_content = self._create_image_details_message(state)

        # Create a new human message with mixed content (text + images)
        human_msg = HumanMessage(content=image_content)

        print("[ViewImageMiddleware] Injecting image details message with images before LLM call")

        # Return state update with the new message
        return {"messages": [human_msg]}

    @override
    def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
        """Inject image details message before LLM call if view_image tools have completed (sync version).

        This runs before each LLM call, checking if the previous turn included view_image
        tool calls that have all completed. If so, it injects a human message with the image
        details so the LLM can see and analyze the images.

        Args:
            state: Current state
            runtime: Runtime context (unused but required by interface)

        Returns:
            State update with additional human message, or None if no update needed
        """
        return self._inject_image_message(state)

    @override
    async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
        """Inject image details message before LLM call if view_image tools have completed (async version).

        This runs before each LLM call, checking if the previous turn included view_image
        tool calls that have all completed. If so, it injects a human message with the image
        details so the LLM can see and analyze the images.

        Args:
            state: Current state
            runtime: Runtime context (unused but required by interface)

        Returns:
            State update with additional human message, or None if no update needed
        """
        return self._inject_image_message(state)
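A short sketch (not from the diff) of the multimodal content blocks this middleware builds:

    state = {
        "viewed_images": {"plots/loss.png": {"base64": "iVBORw0...", "mime_type": "image/png"}},
        "messages": [],
    }
    blocks = ViewImageMiddleware()._create_image_details_message(state)
    # blocks[0] is the text header; blocks[1] names the file and MIME type;
    # blocks[2] is {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
    # the data-URL form most chat models accept for inline images.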
backend/packages/harness/deerflow/agents/thread_state.py (Normal file, 55 lines)
@@ -0,0 +1,55 @@
from typing import Annotated, NotRequired, TypedDict

from langchain.agents import AgentState


class SandboxState(TypedDict):
    sandbox_id: NotRequired[str | None]


class ThreadDataState(TypedDict):
    workspace_path: NotRequired[str | None]
    uploads_path: NotRequired[str | None]
    outputs_path: NotRequired[str | None]


class ViewedImageData(TypedDict):
    base64: str
    mime_type: str


def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
    """Reducer for artifacts list - merges and deduplicates artifacts."""
    if existing is None:
        return new or []
    if new is None:
        return existing
    # Use dict.fromkeys to deduplicate while preserving order
    return list(dict.fromkeys(existing + new))


def merge_viewed_images(existing: dict[str, ViewedImageData] | None, new: dict[str, ViewedImageData] | None) -> dict[str, ViewedImageData]:
    """Reducer for viewed_images dict - merges image dictionaries.

    Special case: If new is an empty dict {}, it clears the existing images.
    This allows middlewares to clear the viewed_images state after processing.
    """
    if existing is None:
        return new or {}
    if new is None:
        return existing
    # Special case: empty dict means clear all viewed images
    if len(new) == 0:
        return {}
    # Merge dictionaries, new values override existing ones for same keys
    return {**existing, **new}


class ThreadState(AgentState):
    sandbox: NotRequired[SandboxState | None]
    thread_data: NotRequired[ThreadDataState | None]
    title: NotRequired[str | None]
    artifacts: Annotated[list[str], merge_artifacts]
    todos: NotRequired[list | None]
    uploaded_files: NotRequired[list[dict] | None]
    viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images]  # image_path -> {base64, mime_type}
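Reducer semantics at a glance (illustrative, not from the diff):

    img: ViewedImageData = {"base64": "abc", "mime_type": "image/png"}

    merge_artifacts(["a.md"], ["b.md", "a.md"])   # ["a.md", "b.md"]: deduped, order kept
    merge_viewed_images(None, {"x.png": img})     # {"x.png": img}
    merge_viewed_images({"x.png": img}, {})       # {}: the empty dict clears the state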
backend/packages/harness/deerflow/client.py (Normal file, 907 lines)
@@ -0,0 +1,907 @@
|
||||
"""DeerFlowClient — Embedded Python client for DeerFlow agent system.
|
||||
|
||||
Provides direct programmatic access to DeerFlow's agent capabilities
|
||||
without requiring LangGraph Server or Gateway API processes.
|
||||
|
||||
Usage:
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
client = DeerFlowClient()
|
||||
response = client.chat("Analyze this paper for me", thread_id="my-thread")
|
||||
print(response)
|
||||
|
||||
# Streaming
|
||||
for event in client.stream("hello"):
|
||||
print(event)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
import zipfile
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.agents.lead_agent.agent import _build_middlewares
|
||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.app_config import get_app_config, reload_app_config
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamEvent:
|
||||
"""A single event from the streaming agent response.
|
||||
|
||||
Event types align with the LangGraph SSE protocol:
|
||||
- ``"values"``: Full state snapshot (title, messages, artifacts).
|
||||
- ``"messages-tuple"``: Per-message update (AI text, tool calls, tool results).
|
||||
- ``"end"``: Stream finished.
|
||||
|
||||
Attributes:
|
||||
type: Event type.
|
||||
data: Event payload. Contents vary by type.
|
||||
"""
|
||||
|
||||
type: str
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DeerFlowClient:
|
||||
"""Embedded Python client for DeerFlow agent system.
|
||||
|
||||
Provides direct programmatic access to DeerFlow's agent capabilities
|
||||
without requiring LangGraph Server or Gateway API processes.
|
||||
|
||||
Note:
|
||||
Multi-turn conversations require a ``checkpointer``. Without one,
|
||||
each ``stream()`` / ``chat()`` call is stateless — ``thread_id``
|
||||
is only used for file isolation (uploads / artifacts).
|
||||
|
||||
The system prompt (including date, memory, and skills context) is
|
||||
generated when the internal agent is first created and cached until
|
||||
the configuration key changes. Call :meth:`reset_agent` to force
|
||||
a refresh in long-running processes.
|
||||
|
||||
Example::
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
client = DeerFlowClient()
|
||||
|
||||
# Simple one-shot
|
||||
print(client.chat("hello"))
|
||||
|
||||
# Streaming
|
||||
for event in client.stream("hello"):
|
||||
print(event.type, event.data)
|
||||
|
||||
# Configuration queries
|
||||
print(client.list_models())
|
||||
print(client.list_skills())
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str | None = None,
|
||||
checkpointer=None,
|
||||
*,
|
||||
model_name: str | None = None,
|
||||
thinking_enabled: bool = True,
|
||||
subagent_enabled: bool = False,
|
||||
plan_mode: bool = False,
|
||||
):
|
||||
"""Initialize the client.
|
||||
|
||||
Loads configuration but defers agent creation to first use.
|
||||
|
||||
Args:
|
||||
config_path: Path to config.yaml. Uses default resolution if None.
|
||||
checkpointer: LangGraph checkpointer instance for state persistence.
|
||||
Required for multi-turn conversations on the same thread_id.
|
||||
Without a checkpointer, each call is stateless.
|
||||
model_name: Override the default model name from config.
|
||||
thinking_enabled: Enable model's extended thinking.
|
||||
subagent_enabled: Enable subagent delegation.
|
||||
plan_mode: Enable TodoList middleware for plan mode.
|
||||
"""
|
||||
if config_path is not None:
|
||||
reload_app_config(config_path)
|
||||
self._app_config = get_app_config()
|
||||
|
||||
self._checkpointer = checkpointer
|
||||
self._model_name = model_name
|
||||
self._thinking_enabled = thinking_enabled
|
||||
self._subagent_enabled = subagent_enabled
|
||||
self._plan_mode = plan_mode
|
||||
|
||||
# Lazy agent — created on first call, recreated when config changes.
|
||||
self._agent = None
|
||||
self._agent_config_key: tuple | None = None
|
||||
|
||||
def reset_agent(self) -> None:
|
||||
"""Force the internal agent to be recreated on the next call.
|
||||
|
||||
Use this after external changes (e.g. memory updates, skill
|
||||
installations) that should be reflected in the system prompt
|
||||
or tool set.
|
||||
"""
|
||||
self._agent = None
|
||||
self._agent_config_key = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _atomic_write_json(path: Path, data: dict) -> None:
|
||||
"""Write JSON to *path* atomically (temp file + replace)."""
|
||||
fd = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
dir=path.parent,
|
||||
suffix=".tmp",
|
||||
delete=False,
|
||||
)
|
||||
try:
|
||||
json.dump(data, fd, indent=2)
|
||||
fd.close()
|
||||
Path(fd.name).replace(path)
|
||||
except BaseException:
|
||||
fd.close()
|
||||
Path(fd.name).unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
def _get_runnable_config(self, thread_id: str, **overrides) -> RunnableConfig:
|
||||
"""Build a RunnableConfig for agent invocation."""
|
||||
configurable = {
|
||||
"thread_id": thread_id,
|
||||
"model_name": overrides.get("model_name", self._model_name),
|
||||
"thinking_enabled": overrides.get("thinking_enabled", self._thinking_enabled),
|
||||
"is_plan_mode": overrides.get("plan_mode", self._plan_mode),
|
||||
"subagent_enabled": overrides.get("subagent_enabled", self._subagent_enabled),
|
||||
}
|
||||
return RunnableConfig(
|
||||
configurable=configurable,
|
||||
recursion_limit=overrides.get("recursion_limit", 100),
|
||||
)
|
||||
|
||||
def _ensure_agent(self, config: RunnableConfig):
|
||||
"""Create (or recreate) the agent when config-dependent params change."""
|
||||
cfg = config.get("configurable", {})
|
||||
key = (
|
||||
cfg.get("model_name"),
|
||||
cfg.get("thinking_enabled"),
|
||||
cfg.get("is_plan_mode"),
|
||||
cfg.get("subagent_enabled"),
|
||||
)
|
||||
|
||||
if self._agent is not None and self._agent_config_key == key:
|
||||
return
|
||||
|
||||
thinking_enabled = cfg.get("thinking_enabled", True)
|
||||
model_name = cfg.get("model_name")
|
||||
subagent_enabled = cfg.get("subagent_enabled", False)
|
||||
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
|
||||
"middleware": _build_middlewares(config, model_name=model_name),
|
||||
"system_prompt": apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
max_concurrent_subagents=max_concurrent_subagents,
|
||||
),
|
||||
"state_schema": ThreadState,
|
||||
}
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.agents.checkpointer import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
if checkpointer is not None:
|
||||
kwargs["checkpointer"] = checkpointer
|
||||
|
||||
self._agent = create_agent(**kwargs)
|
||||
self._agent_config_key = key
|
||||
logger.info("Agent created: model=%s, thinking=%s", model_name, thinking_enabled)
|
||||
|
||||
@staticmethod
|
||||
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
|
||||
"""Lazy import to avoid circular dependency at module level."""
|
||||
from deerflow.tools import get_available_tools
|
||||
|
||||
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_message(msg) -> dict:
|
||||
"""Serialize a LangChain message to a plain dict for values events."""
|
||||
if isinstance(msg, AIMessage):
|
||||
d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if msg.tool_calls:
|
||||
d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls]
|
||||
return d
|
||||
if isinstance(msg, ToolMessage):
|
||||
return {
|
||||
"type": "tool",
|
||||
"content": msg.content if isinstance(msg.content, str) else str(msg.content),
|
||||
"name": getattr(msg, "name", None),
|
||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||
"id": getattr(msg, "id", None),
|
||||
}
|
||||
if isinstance(msg, HumanMessage):
|
||||
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if isinstance(msg, SystemMessage):
|
||||
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(content) -> str:
|
||||
"""Extract plain text from AIMessage content (str or list of blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append(block)
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
parts.append(block["text"])
|
||||
return "\n".join(parts) if parts else ""
|
||||
return str(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — conversation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def stream(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
thread_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamEvent, None, None]:
|
||||
"""Stream a conversation turn, yielding events incrementally.
|
||||
|
||||
Each call sends one user message and yields events until the agent
|
||||
finishes its turn. A ``checkpointer`` must be provided at init time
|
||||
for multi-turn context to be preserved across calls.
|
||||
|
||||
Event types align with the LangGraph SSE protocol so that
|
||||
consumers can switch between HTTP streaming and embedded mode
|
||||
without changing their event-handling logic.
|
||||
|
||||
Args:
|
||||
message: User message text.
|
||||
thread_id: Thread ID for conversation context. Auto-generated if None.
|
||||
**kwargs: Override client defaults (model_name, thinking_enabled,
|
||||
plan_mode, subagent_enabled, recursion_limit).
|
||||
|
||||
Yields:
|
||||
StreamEvent with one of:
|
||||
- type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
|
||||
- type="messages-tuple" data={"type": "ai", "content": str, "id": str}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
||||
- type="end" data={}
|
||||
"""
|
||||
if thread_id is None:
|
||||
thread_id = str(uuid.uuid4())
|
||||
|
||||
config = self._get_runnable_config(thread_id, **kwargs)
|
||||
self._ensure_agent(config)
|
||||
|
||||
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
|
||||
context = {"thread_id": thread_id}
|
||||
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"):
|
||||
messages = chunk.get("messages", [])
|
||||
|
||||
for msg in messages:
|
||||
msg_id = getattr(msg, "id", None)
|
||||
if msg_id and msg_id in seen_ids:
|
||||
continue
|
||||
if msg_id:
|
||||
seen_ids.add(msg_id)
|
||||
|
||||
if isinstance(msg, AIMessage):
|
||||
if msg.tool_calls:
|
||||
yield StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
text = self._extract_text(msg.content)
|
||||
if text:
|
||||
yield StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={"type": "ai", "content": text, "id": msg_id},
|
||||
)
|
||||
|
||||
elif isinstance(msg, ToolMessage):
|
||||
yield StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "tool",
|
||||
"content": msg.content if isinstance(msg.content, str) else str(msg.content),
|
||||
"name": getattr(msg, "name", None),
|
||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||
"id": msg_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Emit a values event for each state snapshot
|
||||
yield StreamEvent(
|
||||
type="values",
|
||||
data={
|
||||
"title": chunk.get("title"),
|
||||
"messages": [self._serialize_message(m) for m in messages],
|
||||
"artifacts": chunk.get("artifacts", []),
|
||||
},
|
||||
)
|
||||
|
||||
yield StreamEvent(type="end", data={})
|
||||
|
||||
def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str:
|
||||
"""Send a message and return the final text response.
|
||||
|
||||
Convenience wrapper around :meth:`stream` that returns only the
|
||||
**last** AI text from ``messages-tuple`` events. If the agent emits
|
||||
multiple text segments in one turn, intermediate segments are
|
||||
discarded. Use :meth:`stream` directly to capture all events.
|
||||
|
||||
Args:
|
||||
message: User message text.
|
||||
thread_id: Thread ID for conversation context. Auto-generated if None.
|
||||
**kwargs: Override client defaults (same as stream()).
|
||||
|
||||
Returns:
|
||||
The last AI message text, or empty string if no response.
|
||||
"""
|
||||
last_text = ""
|
||||
for event in self.stream(message, thread_id=thread_id, **kwargs):
|
||||
if event.type == "messages-tuple" and event.data.get("type") == "ai":
|
||||
content = event.data.get("content", "")
|
||||
if content:
|
||||
last_text = content
|
||||
return last_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — configuration queries
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_models(self) -> dict:
|
||||
"""List available models from configuration.
|
||||
|
||||
Returns:
|
||||
Dict with "models" key containing list of model info dicts,
|
||||
matching the Gateway API ``ModelsListResponse`` schema.
|
||||
"""
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
"name": model.name,
|
||||
"display_name": getattr(model, "display_name", None),
|
||||
"description": getattr(model, "description", None),
|
||||
"supports_thinking": getattr(model, "supports_thinking", False),
|
||||
"supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False),
|
||||
}
|
||||
for model in self._app_config.models
|
||||
]
|
||||
}

    def list_skills(self, enabled_only: bool = False) -> dict:
        """List available skills.

        Args:
            enabled_only: If True, only return enabled skills.

        Returns:
            Dict with "skills" key containing list of skill info dicts,
            matching the Gateway API ``SkillsListResponse`` schema.
        """
        from deerflow.skills.loader import load_skills

        return {
            "skills": [
                {
                    "name": s.name,
                    "description": s.description,
                    "license": s.license,
                    "category": s.category,
                    "enabled": s.enabled,
                }
                for s in load_skills(enabled_only=enabled_only)
            ]
        }

    def get_memory(self) -> dict:
        """Get current memory data.

        Returns:
            Memory data dict (see deerflow/agents/memory/updater.py for structure).
        """
        from deerflow.agents.memory.updater import get_memory_data

        return get_memory_data()

    def get_model(self, name: str) -> dict | None:
        """Get a specific model's configuration by name.

        Args:
            name: Model name.

        Returns:
            Model info dict matching the Gateway API ``ModelResponse``
            schema, or None if not found.
        """
        model = self._app_config.get_model_config(name)
        if model is None:
            return None
        return {
            "name": model.name,
            "display_name": getattr(model, "display_name", None),
            "description": getattr(model, "description", None),
            "supports_thinking": getattr(model, "supports_thinking", False),
            "supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False),
        }

    # ------------------------------------------------------------------
    # Public API — MCP configuration
    # ------------------------------------------------------------------

    def get_mcp_config(self) -> dict:
        """Get MCP server configurations.

        Returns:
            Dict with "mcp_servers" key mapping server name to config,
            matching the Gateway API ``McpConfigResponse`` schema.
        """
        config = get_extensions_config()
        return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}}

    def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
        """Update MCP server configurations.

        Writes to extensions_config.json and reloads the cache.

        Args:
            mcp_servers: Dict mapping server name to config dict. Each value
                should contain keys like enabled, type, command, args, env, url, etc.

        Returns:
            Dict with "mcp_servers" key, matching the Gateway API
            ``McpConfigResponse`` schema.

        Raises:
            OSError: If the config file cannot be written.
        """
        config_path = ExtensionsConfig.resolve_config_path()
        if config_path is None:
            raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")

        current_config = get_extensions_config()

        config_data = {
            "mcpServers": mcp_servers,
            "skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
        }

        self._atomic_write_json(config_path, config_data)

        self._agent = None
        reloaded = reload_extensions_config()
        return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}

    # ------------------------------------------------------------------
    # Public API — skills management
    # ------------------------------------------------------------------

    def get_skill(self, name: str) -> dict | None:
        """Get a specific skill by name.

        Args:
            name: Skill name.

        Returns:
            Skill info dict, or None if not found.
        """
        from deerflow.skills.loader import load_skills

        skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
        if skill is None:
            return None
        return {
            "name": skill.name,
            "description": skill.description,
            "license": skill.license,
            "category": skill.category,
            "enabled": skill.enabled,
        }

    def update_skill(self, name: str, *, enabled: bool) -> dict:
        """Update a skill's enabled status.

        Args:
            name: Skill name.
            enabled: New enabled status.

        Returns:
            Updated skill info dict.

        Raises:
            ValueError: If the skill is not found.
            OSError: If the config file cannot be written.
        """
        from deerflow.skills.loader import load_skills

        skills = load_skills(enabled_only=False)
        skill = next((s for s in skills if s.name == name), None)
        if skill is None:
            raise ValueError(f"Skill '{name}' not found")

        config_path = ExtensionsConfig.resolve_config_path()
        if config_path is None:
            raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")

        extensions_config = get_extensions_config()
        extensions_config.skills[name] = SkillStateConfig(enabled=enabled)

        config_data = {
            "mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()},
            "skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()},
        }

        self._atomic_write_json(config_path, config_data)

        self._agent = None
        reload_extensions_config()

        updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
        if updated is None:
            raise RuntimeError(f"Skill '{name}' disappeared after update")
        return {
            "name": updated.name,
            "description": updated.description,
            "license": updated.license,
            "category": updated.category,
            "enabled": updated.enabled,
        }
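
    # Usage sketch (editor's illustration): toggling a skill off and back on.
    # The skill name "pdf-report" is hypothetical.
    #
    #     client.update_skill("pdf-report", enabled=False)
    #     assert client.get_skill("pdf-report")["enabled"] is False
    #     client.update_skill("pdf-report", enabled=True)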

    def install_skill(self, skill_path: str | Path) -> dict:
        """Install a skill from a .skill archive (ZIP).

        Args:
            skill_path: Path to the .skill file.

        Returns:
            Dict with success, skill_name, message.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is invalid.
        """
        from deerflow.skills.loader import get_skills_root_path
        from deerflow.skills.validation import _validate_skill_frontmatter

        path = Path(skill_path)
        if not path.exists():
            raise FileNotFoundError(f"Skill file not found: {skill_path}")
        if not path.is_file():
            raise ValueError(f"Path is not a file: {skill_path}")
        if path.suffix != ".skill":
            raise ValueError("File must have .skill extension")
        if not zipfile.is_zipfile(path):
            raise ValueError("File is not a valid ZIP archive")

        skills_root = get_skills_root_path()
        custom_dir = skills_root / "custom"
        custom_dir.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            with zipfile.ZipFile(path, "r") as zf:
                total_size = sum(info.file_size for info in zf.infolist())
                if total_size > 100 * 1024 * 1024:
                    raise ValueError("Skill archive too large when extracted (>100MB)")
                for info in zf.infolist():
                    if Path(info.filename).is_absolute() or ".." in Path(info.filename).parts:
                        raise ValueError(f"Unsafe path in archive: {info.filename}")
                zf.extractall(tmp_path)
            for p in tmp_path.rglob("*"):
                if p.is_symlink():
                    p.unlink()

            items = list(tmp_path.iterdir())
            if not items:
                raise ValueError("Skill archive is empty")

            skill_dir = items[0] if len(items) == 1 and items[0].is_dir() else tmp_path

            is_valid, message, skill_name = _validate_skill_frontmatter(skill_dir)
            if not is_valid:
                raise ValueError(f"Invalid skill: {message}")
            if not re.fullmatch(r"[a-zA-Z0-9_-]+", skill_name):
                raise ValueError(f"Invalid skill name: {skill_name}")

            target = custom_dir / skill_name
            if target.exists():
                raise ValueError(f"Skill '{skill_name}' already exists")

            shutil.copytree(skill_dir, target)

        return {"success": True, "skill_name": skill_name, "message": f"Skill '{skill_name}' installed successfully"}

    # ------------------------------------------------------------------
    # Public API — memory management
    # ------------------------------------------------------------------

    def reload_memory(self) -> dict:
        """Reload memory data from file, forcing cache invalidation.

        Returns:
            The reloaded memory data dict.
        """
        from deerflow.agents.memory.updater import reload_memory_data

        return reload_memory_data()

    def get_memory_config(self) -> dict:
        """Get memory system configuration.

        Returns:
            Memory config dict.
        """
        from deerflow.config.memory_config import get_memory_config

        config = get_memory_config()
        return {
            "enabled": config.enabled,
            "storage_path": config.storage_path,
            "debounce_seconds": config.debounce_seconds,
            "max_facts": config.max_facts,
            "fact_confidence_threshold": config.fact_confidence_threshold,
            "injection_enabled": config.injection_enabled,
            "max_injection_tokens": config.max_injection_tokens,
        }

    def get_memory_status(self) -> dict:
        """Get memory status: config + current data.

        Returns:
            Dict with "config" and "data" keys.
        """
        return {
            "config": self.get_memory_config(),
            "data": self.get_memory(),
        }

    # ------------------------------------------------------------------
    # Public API — file uploads
    # ------------------------------------------------------------------

    @staticmethod
    def _get_uploads_dir(thread_id: str) -> Path:
        """Get (and create) the uploads directory for a thread."""
        base = get_paths().sandbox_uploads_dir(thread_id)
        base.mkdir(parents=True, exist_ok=True)
        return base

    def upload_files(self, thread_id: str, files: list[str | Path]) -> dict:
        """Upload local files into a thread's uploads directory.

        PDF, PPT, Excel, and Word files are additionally converted to Markdown.

        Args:
            thread_id: Target thread ID.
            files: List of local file paths to upload.

        Returns:
            Dict with success, files, message — matching the Gateway API
            ``UploadResponse`` schema.

        Raises:
            FileNotFoundError: If any file does not exist.
            ValueError: If any supplied path exists but is not a regular file.
        """
        from deerflow.utils.file_conversion import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown

        # Validate all files upfront to avoid partial uploads.
        resolved_files = []
        convertible_extensions = {ext.lower() for ext in CONVERTIBLE_EXTENSIONS}
        has_convertible_file = False
        for f in files:
            p = Path(f)
            if not p.exists():
                raise FileNotFoundError(f"File not found: {f}")
            if not p.is_file():
                raise ValueError(f"Path is not a file: {f}")
            resolved_files.append(p)
            if not has_convertible_file and p.suffix.lower() in convertible_extensions:
                has_convertible_file = True

        uploads_dir = self._get_uploads_dir(thread_id)
        uploaded_files: list[dict] = []

        conversion_pool = None
        if has_convertible_file:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                conversion_pool = None
            else:
                import concurrent.futures

                # Reuse one worker when already inside an event loop to avoid
                # creating a new ThreadPoolExecutor per converted file.
                conversion_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        def _convert_in_thread(path: Path):
            return asyncio.run(convert_file_to_markdown(path))

        try:
            for src_path in resolved_files:
                dest = uploads_dir / src_path.name
                shutil.copy2(src_path, dest)

                info: dict[str, Any] = {
                    "filename": src_path.name,
                    "size": str(dest.stat().st_size),
                    "path": str(dest),
                    "virtual_path": f"/mnt/user-data/uploads/{src_path.name}",
                    "artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{src_path.name}",
                }

                if src_path.suffix.lower() in convertible_extensions:
                    try:
                        if conversion_pool is not None:
                            md_path = conversion_pool.submit(_convert_in_thread, dest).result()
                        else:
                            md_path = asyncio.run(convert_file_to_markdown(dest))
                    except Exception:
                        logger.warning(
                            "Failed to convert %s to markdown",
                            src_path.name,
                            exc_info=True,
                        )
                        md_path = None

                    if md_path is not None:
                        info["markdown_file"] = md_path.name
                        info["markdown_virtual_path"] = f"/mnt/user-data/uploads/{md_path.name}"
                        info["markdown_artifact_url"] = f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{md_path.name}"

                uploaded_files.append(info)
        finally:
            if conversion_pool is not None:
                conversion_pool.shutdown(wait=True)

        return {
            "success": True,
            "files": uploaded_files,
            "message": f"Successfully uploaded {len(uploaded_files)} file(s)",
        }
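
    # Usage sketch (editor's illustration): uploading a document before a chat
    # turn so the agent can read it from /mnt/user-data/uploads. The local path
    # is hypothetical.
    #
    #     result = client.upload_files("t-1", ["/tmp/q3-report.pdf"])
    #     for f in result["files"]:
    #         print(f["virtual_path"], f.get("markdown_virtual_path"))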

    def list_uploads(self, thread_id: str) -> dict:
        """List files in a thread's uploads directory.

        Args:
            thread_id: Thread ID.

        Returns:
            Dict with "files" and "count" keys, matching the Gateway API
            ``list_uploaded_files`` response.
        """
        uploads_dir = self._get_uploads_dir(thread_id)
        if not uploads_dir.exists():
            return {"files": [], "count": 0}

        files = []
        with os.scandir(uploads_dir) as entries:
            file_entries = [entry for entry in entries if entry.is_file()]

        for entry in sorted(file_entries, key=lambda item: item.name):
            stat = entry.stat()
            filename = entry.name
            files.append(
                {
                    "filename": filename,
                    "size": str(stat.st_size),
                    "path": str(Path(entry.path)),
                    "virtual_path": f"/mnt/user-data/uploads/{filename}",
                    "artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{filename}",
                    "extension": Path(filename).suffix,
                    "modified": stat.st_mtime,
                }
            )
        return {"files": files, "count": len(files)}

    def delete_upload(self, thread_id: str, filename: str) -> dict:
        """Delete a file from a thread's uploads directory.

        Args:
            thread_id: Thread ID.
            filename: Filename to delete.

        Returns:
            Dict with success and message, matching the Gateway API
            ``delete_uploaded_file`` response.

        Raises:
            FileNotFoundError: If the file does not exist.
            PermissionError: If path traversal is detected.
        """
        uploads_dir = self._get_uploads_dir(thread_id)
        file_path = (uploads_dir / filename).resolve()

        try:
            file_path.relative_to(uploads_dir.resolve())
        except ValueError as exc:
            raise PermissionError("Access denied: path traversal detected") from exc

        if not file_path.is_file():
            raise FileNotFoundError(f"File not found: {filename}")

        file_path.unlink()
        return {"success": True, "message": f"Deleted {filename}"}

    # ------------------------------------------------------------------
    # Public API — artifacts
    # ------------------------------------------------------------------

    def get_artifact(self, thread_id: str, path: str) -> tuple[bytes, str]:
        """Read an artifact file produced by the agent.

        Args:
            thread_id: Thread ID.
            path: Virtual path (e.g. "mnt/user-data/outputs/file.txt").

        Returns:
            Tuple of (file_bytes, mime_type).

        Raises:
            FileNotFoundError: If the artifact does not exist.
            ValueError: If the path is invalid.
            PermissionError: If path traversal is detected.
        """
        virtual_prefix = "mnt/user-data"
        clean_path = path.lstrip("/")
        if not clean_path.startswith(virtual_prefix):
            raise ValueError(f"Path must start with /{virtual_prefix}")

        relative = clean_path[len(virtual_prefix) :].lstrip("/")
        base_dir = get_paths().sandbox_user_data_dir(thread_id)
        actual = (base_dir / relative).resolve()

        try:
            actual.relative_to(base_dir.resolve())
        except ValueError as exc:
            raise PermissionError("Access denied: path traversal detected") from exc
        if not actual.exists():
            raise FileNotFoundError(f"Artifact not found: {path}")
        if not actual.is_file():
            raise ValueError(f"Path is not a file: {path}")

        mime_type, _ = mimetypes.guess_type(actual)
        return actual.read_bytes(), mime_type or "application/octet-stream"
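
    # Usage sketch (editor's illustration): fetching an artifact the agent wrote
    # to its outputs directory and saving it locally. The artifact path is
    # hypothetical.
    #
    #     data, mime = client.get_artifact("t-1", "mnt/user-data/outputs/report.pdf")
    #     Path("report.pdf").write_bytes(data)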
@@ -0,0 +1,15 @@
from .aio_sandbox import AioSandbox
from .aio_sandbox_provider import AioSandboxProvider
from .backend import SandboxBackend
from .local_backend import LocalContainerBackend
from .remote_backend import RemoteSandboxBackend
from .sandbox_info import SandboxInfo

__all__ = [
    "AioSandbox",
    "AioSandboxProvider",
    "LocalContainerBackend",
    "RemoteSandboxBackend",
    "SandboxBackend",
    "SandboxInfo",
]
@@ -0,0 +1,128 @@
import base64
import logging

from agent_sandbox import Sandbox as AioSandboxClient

from deerflow.sandbox.sandbox import Sandbox

logger = logging.getLogger(__name__)


class AioSandbox(Sandbox):
    """Sandbox implementation using the agent-infra/sandbox Docker container.

    This sandbox connects to a running AIO sandbox container via HTTP API.
    """

    def __init__(self, id: str, base_url: str, home_dir: str | None = None):
        """Initialize the AIO sandbox.

        Args:
            id: Unique identifier for this sandbox instance.
            base_url: URL of the sandbox API (e.g., http://localhost:8080).
            home_dir: Home directory inside the sandbox. If None, it will be fetched from the sandbox.
        """
        super().__init__(id)
        self._base_url = base_url
        self._client = AioSandboxClient(base_url=base_url, timeout=600)
        self._home_dir = home_dir

    @property
    def base_url(self) -> str:
        return self._base_url

    @property
    def home_dir(self) -> str:
        """Get the home directory inside the sandbox."""
        if self._home_dir is None:
            context = self._client.sandbox.get_context()
            self._home_dir = context.home_dir
        return self._home_dir

    def execute_command(self, command: str) -> str:
        """Execute a shell command in the sandbox.

        Args:
            command: The command to execute.

        Returns:
            The output of the command.
        """
        try:
            result = self._client.shell.exec_command(command=command)
            output = result.data.output if result.data else ""
            return output if output else "(no output)"
        except Exception as e:
            logger.error(f"Failed to execute command in sandbox: {e}")
            return f"Error: {e}"

    def read_file(self, path: str) -> str:
        """Read the content of a file in the sandbox.

        Args:
            path: The absolute path of the file to read.

        Returns:
            The content of the file.
        """
        try:
            result = self._client.file.read_file(file=path)
            return result.data.content if result.data else ""
        except Exception as e:
            logger.error(f"Failed to read file in sandbox: {e}")
            return f"Error: {e}"

    def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
        """List the contents of a directory in the sandbox.

        Args:
            path: The absolute path of the directory to list.
            max_depth: The maximum depth to traverse. Default is 2.

        Returns:
            A list of file and directory paths under ``path``.
        """
        try:
            # Use a shell command to list the directory with a depth limit;
            # find's -maxdepth flag bounds the traversal.
            result = self._client.shell.exec_command(command=f"find {path} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
            output = result.data.output if result.data else ""
            if output:
                return [line.strip() for line in output.strip().split("\n") if line.strip()]
            return []
        except Exception as e:
            logger.error(f"Failed to list directory in sandbox: {e}")
            return []

    def write_file(self, path: str, content: str, append: bool = False) -> None:
        """Write content to a file in the sandbox.

        Args:
            path: The absolute path of the file to write to.
            content: The text content to write to the file.
            append: Whether to append the content to the file.
        """
        try:
            if append:
                # Read existing content first and append
                existing = self.read_file(path)
                if not existing.startswith("Error:"):
                    content = existing + content
            self._client.file.write_file(file=path, content=content)
        except Exception as e:
            logger.error(f"Failed to write file in sandbox: {e}")
            raise

    def update_file(self, path: str, content: bytes) -> None:
        """Update a file with binary content in the sandbox.

        Args:
            path: The absolute path of the file to update.
            content: The binary content to write to the file.
        """
        try:
            base64_content = base64.b64encode(content).decode("utf-8")
            self._client.file.write_file(file=path, content=base64_content, encoding="base64")
        except Exception as e:
            logger.error(f"Failed to update file in sandbox: {e}")
            raise
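
# Usage sketch (editor's illustration): driving a running AIO sandbox container
# directly. The URL assumes a container already listening on localhost:8080.
#
#     sb = AioSandbox(id="demo", base_url="http://localhost:8080")
#     sb.write_file(f"{sb.home_dir}/hello.txt", "hi\n")
#     print(sb.execute_command(f"cat {sb.home_dir}/hello.txt"))
#     print(sb.list_dir(sb.home_dir, max_depth=1))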
@@ -0,0 +1,609 @@
"""AIO Sandbox Provider — orchestrates sandbox lifecycle with pluggable backends.

This provider composes:
- SandboxBackend: how sandboxes are provisioned (local container vs remote/K8s)

The provider itself handles:
- In-process caching for fast repeated access
- Idle timeout management
- Graceful shutdown with signal handling
- Mount computation (thread-specific, skills)
"""

import atexit
import fcntl
import hashlib
import logging
import os
import signal
import threading
import time
import uuid

from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, Paths, get_paths
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider

from .aio_sandbox import AioSandbox
from .backend import SandboxBackend, wait_for_sandbox_ready
from .local_backend import LocalContainerBackend
from .remote_backend import RemoteSandboxBackend
from .sandbox_info import SandboxInfo

logger = logging.getLogger(__name__)

# Default configuration
DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"
DEFAULT_PORT = 8080
DEFAULT_CONTAINER_PREFIX = "deer-flow-sandbox"
DEFAULT_IDLE_TIMEOUT = 600  # 10 minutes in seconds
DEFAULT_REPLICAS = 3  # Maximum concurrent sandbox containers
IDLE_CHECK_INTERVAL = 60  # Check every 60 seconds


class AioSandboxProvider(SandboxProvider):
    """Sandbox provider that manages containers running the AIO sandbox.

    Architecture:
        This provider composes a SandboxBackend (how to provision), enabling:
        - Local Docker/Apple Container mode (auto-start containers)
        - Remote/K8s mode (connect to pre-existing sandbox URL)

    Configuration options in config.yaml under sandbox:
        use: deerflow.community.aio_sandbox:AioSandboxProvider
        image: <container image>
        port: 8080  # Base port for local containers
        container_prefix: deer-flow-sandbox
        idle_timeout: 600  # Idle timeout in seconds (0 to disable)
        replicas: 3  # Max concurrent sandbox containers (LRU eviction when exceeded)
        mounts:  # Volume mounts for local containers
          - host_path: /path/on/host
            container_path: /path/in/container
            read_only: false
        environment:  # Environment variables for containers
          NODE_ENV: production
          API_KEY: $MY_API_KEY
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._sandboxes: dict[str, AioSandbox] = {}  # sandbox_id -> AioSandbox instance
        self._sandbox_infos: dict[str, SandboxInfo] = {}  # sandbox_id -> SandboxInfo (for destroy)
        self._thread_sandboxes: dict[str, str] = {}  # thread_id -> sandbox_id
        self._thread_locks: dict[str, threading.Lock] = {}  # thread_id -> in-process lock
        self._last_activity: dict[str, float] = {}  # sandbox_id -> last activity timestamp
        # Warm pool: released sandboxes whose containers are still running.
        # Maps sandbox_id -> (SandboxInfo, release_timestamp).
        # Containers here can be reclaimed quickly (no cold-start) or destroyed
        # when replicas capacity is exhausted.
        self._warm_pool: dict[str, tuple[SandboxInfo, float]] = {}
        self._shutdown_called = False
        self._idle_checker_stop = threading.Event()
        self._idle_checker_thread: threading.Thread | None = None

        self._config = self._load_config()
        self._backend: SandboxBackend = self._create_backend()

        # Register shutdown handler
        atexit.register(self.shutdown)
        self._register_signal_handlers()

        # Start idle checker if enabled
        if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
            self._start_idle_checker()

    # ── Factory methods ──────────────────────────────────────────────────

    def _create_backend(self) -> SandboxBackend:
        """Create the appropriate backend based on configuration.

        Selection logic (checked in order):
        1. ``provisioner_url`` set → RemoteSandboxBackend (provisioner mode)
           Provisioner dynamically creates Pods + Services in k3s.
        2. Default → LocalContainerBackend (local mode)
           Local provider manages container lifecycle directly (start/stop).
        """
        provisioner_url = self._config.get("provisioner_url")
        if provisioner_url:
            logger.info(f"Using remote sandbox backend with provisioner at {provisioner_url}")
            return RemoteSandboxBackend(provisioner_url=provisioner_url)

        logger.info("Using local container sandbox backend")
        return LocalContainerBackend(
            image=self._config["image"],
            base_port=self._config["port"],
            container_prefix=self._config["container_prefix"],
            config_mounts=self._config["mounts"],
            environment=self._config["environment"],
        )

    # ── Configuration ────────────────────────────────────────────────────

    def _load_config(self) -> dict:
        """Load sandbox configuration from app config."""
        config = get_app_config()
        sandbox_config = config.sandbox

        idle_timeout = getattr(sandbox_config, "idle_timeout", None)
        replicas = getattr(sandbox_config, "replicas", None)

        return {
            "image": sandbox_config.image or DEFAULT_IMAGE,
            "port": sandbox_config.port or DEFAULT_PORT,
            "container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
            "idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
            "replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
            "mounts": sandbox_config.mounts or [],
            "environment": self._resolve_env_vars(sandbox_config.environment or {}),
            # provisioner URL for dynamic pod management (e.g. http://provisioner:8002)
            "provisioner_url": getattr(sandbox_config, "provisioner_url", None) or "",
        }

    @staticmethod
    def _resolve_env_vars(env_config: dict[str, str]) -> dict[str, str]:
        """Resolve environment variable references (values starting with $)."""
        resolved = {}
        for key, value in env_config.items():
            if isinstance(value, str) and value.startswith("$"):
                env_name = value[1:]
                resolved[key] = os.environ.get(env_name, "")
            else:
                resolved[key] = str(value)
        return resolved

    # ── Deterministic ID ─────────────────────────────────────────────────

    @staticmethod
    def _deterministic_sandbox_id(thread_id: str) -> str:
        """Generate a deterministic sandbox ID from a thread ID.

        Ensures all processes derive the same sandbox_id for a given thread,
        enabling cross-process sandbox discovery without shared memory.
        """
        return hashlib.sha256(thread_id.encode()).hexdigest()[:8]
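
    # Worked example (editor's illustration): the ID is simply the first 8 hex
    # characters of the SHA-256 digest, so every process derives the same
    # container name for a given thread:
    #
    #     AioSandboxProvider._deterministic_sandbox_id("thread-42") \
    #         == hashlib.sha256(b"thread-42").hexdigest()[:8]  # always True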

    # ── Mount helpers ────────────────────────────────────────────────────

    def _get_extra_mounts(self, thread_id: str | None) -> list[tuple[str, str, bool]]:
        """Collect all extra mounts for a sandbox (thread-specific + skills)."""
        mounts: list[tuple[str, str, bool]] = []

        if thread_id:
            mounts.extend(self._get_thread_mounts(thread_id))
            logger.info(f"Adding thread mounts for thread {thread_id}: {mounts}")

        skills_mount = self._get_skills_mount()
        if skills_mount:
            mounts.append(skills_mount)
            logger.info(f"Adding skills mount: {skills_mount}")

        return mounts

    @staticmethod
    def _get_thread_mounts(thread_id: str) -> list[tuple[str, str, bool]]:
        """Get volume mounts for a thread's data directories.

        Creates directories if they don't exist (lazy initialization).
        Mount sources use host_base_dir so that when running inside Docker with a
        mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
        """
        paths = get_paths()
        paths.ensure_thread_dirs(thread_id)

        # host_paths resolves to the host-side base dir when DEER_FLOW_HOST_BASE_DIR
        # is set, otherwise falls back to the container's own base dir (native mode).
        host_paths = Paths(base_dir=paths.host_base_dir)

        return [
            (str(host_paths.sandbox_work_dir(thread_id)), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
            (str(host_paths.sandbox_uploads_dir(thread_id)), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
            (str(host_paths.sandbox_outputs_dir(thread_id)), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
        ]

    @staticmethod
    def _get_skills_mount() -> tuple[str, str, bool] | None:
        """Get the skills directory mount configuration.

        Mount source uses DEER_FLOW_HOST_SKILLS_PATH when running inside Docker (DooD)
        so the host Docker daemon can resolve the path.
        """
        try:
            config = get_app_config()
            skills_path = config.skills.get_skills_path()
            container_path = config.skills.container_path

            if skills_path.exists():
                # When running inside Docker with DooD, use host-side skills path.
                host_skills = os.environ.get("DEER_FLOW_HOST_SKILLS_PATH") or str(skills_path)
                return (host_skills, container_path, True)  # Read-only for security
        except Exception as e:
            logger.warning(f"Could not setup skills mount: {e}")
        return None

    # ── Idle timeout management ──────────────────────────────────────────

    def _start_idle_checker(self) -> None:
        """Start the background thread that checks for idle sandboxes."""
        self._idle_checker_thread = threading.Thread(
            target=self._idle_checker_loop,
            name="sandbox-idle-checker",
            daemon=True,
        )
        self._idle_checker_thread.start()
        logger.info(f"Started idle checker thread (timeout: {self._config.get('idle_timeout', DEFAULT_IDLE_TIMEOUT)}s)")

    def _idle_checker_loop(self) -> None:
        idle_timeout = self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT)
        while not self._idle_checker_stop.wait(timeout=IDLE_CHECK_INTERVAL):
            try:
                self._cleanup_idle_sandboxes(idle_timeout)
            except Exception as e:
                logger.error(f"Error in idle checker loop: {e}")

    def _cleanup_idle_sandboxes(self, idle_timeout: float) -> None:
        current_time = time.time()
        active_to_destroy = []
        warm_to_destroy: list[tuple[str, SandboxInfo]] = []

        with self._lock:
            # Active sandboxes: tracked via _last_activity
            for sandbox_id, last_activity in self._last_activity.items():
                idle_duration = current_time - last_activity
                if idle_duration > idle_timeout:
                    active_to_destroy.append(sandbox_id)
                    logger.info(f"Sandbox {sandbox_id} idle for {idle_duration:.1f}s, marking for destroy")

            # Warm pool: tracked via release_timestamp stored in _warm_pool
            for sandbox_id, (info, release_ts) in list(self._warm_pool.items()):
                warm_duration = current_time - release_ts
                if warm_duration > idle_timeout:
                    warm_to_destroy.append((sandbox_id, info))
                    del self._warm_pool[sandbox_id]
                    logger.info(f"Warm-pool sandbox {sandbox_id} idle for {warm_duration:.1f}s, marking for destroy")

        # Destroy active sandboxes (re-verify still idle before acting)
        for sandbox_id in active_to_destroy:
            try:
                # Re-verify the sandbox is still idle under the lock before destroying.
                # Between the snapshot above and here, the sandbox may have been
                # re-acquired (last_activity updated) or already released/destroyed.
                with self._lock:
                    last_activity = self._last_activity.get(sandbox_id)
                    if last_activity is None:
                        # Already released or destroyed by another path — skip.
                        logger.info(f"Sandbox {sandbox_id} already gone before idle destroy, skipping")
                        continue
                    if (time.time() - last_activity) < idle_timeout:
                        # Re-acquired (activity updated) since the snapshot — skip.
                        logger.info(f"Sandbox {sandbox_id} was re-acquired before idle destroy, skipping")
                        continue
                logger.info(f"Destroying idle sandbox {sandbox_id}")
                self.destroy(sandbox_id)
            except Exception as e:
                logger.error(f"Failed to destroy idle sandbox {sandbox_id}: {e}")

        # Destroy warm-pool sandboxes (already removed from _warm_pool under lock above)
        for sandbox_id, info in warm_to_destroy:
            try:
                self._backend.destroy(info)
                logger.info(f"Destroyed idle warm-pool sandbox {sandbox_id}")
            except Exception as e:
                logger.error(f"Failed to destroy idle warm-pool sandbox {sandbox_id}: {e}")

    # ── Signal handling ──────────────────────────────────────────────────

    def _register_signal_handlers(self) -> None:
        """Register signal handlers for graceful shutdown."""
        self._original_sigterm = signal.getsignal(signal.SIGTERM)
        self._original_sigint = signal.getsignal(signal.SIGINT)

        def signal_handler(signum, frame):
            self.shutdown()
            original = self._original_sigterm if signum == signal.SIGTERM else self._original_sigint
            if callable(original):
                original(signum, frame)
            elif original == signal.SIG_DFL:
                signal.signal(signum, signal.SIG_DFL)
                signal.raise_signal(signum)

        try:
            signal.signal(signal.SIGTERM, signal_handler)
            signal.signal(signal.SIGINT, signal_handler)
        except ValueError:
            logger.debug("Could not register signal handlers (not main thread)")

    # ── Thread locking (in-process) ──────────────────────────────────────

    def _get_thread_lock(self, thread_id: str) -> threading.Lock:
        """Get or create an in-process lock for a specific thread_id."""
        with self._lock:
            if thread_id not in self._thread_locks:
                self._thread_locks[thread_id] = threading.Lock()
            return self._thread_locks[thread_id]

    # ── Core: acquire / get / release / shutdown ─────────────────────────

    def acquire(self, thread_id: str | None = None) -> str:
        """Acquire a sandbox environment and return its ID.

        For the same thread_id, this method will return the same sandbox_id
        across multiple turns, multiple processes, and (with shared storage)
        multiple pods.

        Thread-safe with both in-process and cross-process locking.

        Args:
            thread_id: Optional thread ID for thread-specific configurations.

        Returns:
            The ID of the acquired sandbox environment.
        """
        if thread_id:
            thread_lock = self._get_thread_lock(thread_id)
            with thread_lock:
                return self._acquire_internal(thread_id)
        else:
            return self._acquire_internal(thread_id)
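
    # Usage sketch (editor's illustration): the typical lifecycle around a turn.
    # Repeated acquire() calls for the same thread_id return the same sandbox.
    #
    #     provider = AioSandboxProvider()
    #     sid = provider.acquire(thread_id="t-1")
    #     sandbox = provider.get(sid)
    #     sandbox.execute_command("echo hello")
    #     provider.release(sid)  # parks the container in the warm pool
    #     assert provider.acquire(thread_id="t-1") == sid  # reclaimed, no cold-start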

    def _acquire_internal(self, thread_id: str | None) -> str:
        """Internal sandbox acquisition with two-layer consistency.

        Layer 1: In-process cache (fastest, covers same-process repeated access)
        Layer 2: Backend discovery (covers containers started by other processes;
            sandbox_id is deterministic from thread_id so no shared state file
            is needed — any process can derive the same container name)
        """
        # ── Layer 1: In-process cache (fast path) ──
        if thread_id:
            with self._lock:
                if thread_id in self._thread_sandboxes:
                    existing_id = self._thread_sandboxes[thread_id]
                    if existing_id in self._sandboxes:
                        logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}")
                        self._last_activity[existing_id] = time.time()
                        return existing_id
                    else:
                        del self._thread_sandboxes[thread_id]

        # Deterministic ID for thread-specific, random for anonymous
        sandbox_id = self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8]

        # ── Layer 1.5: Warm pool (container still running, no cold-start) ──
        if thread_id:
            with self._lock:
                if sandbox_id in self._warm_pool:
                    info, _ = self._warm_pool.pop(sandbox_id)
                    sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
                    self._sandboxes[sandbox_id] = sandbox
                    self._sandbox_infos[sandbox_id] = info
                    self._last_activity[sandbox_id] = time.time()
                    self._thread_sandboxes[thread_id] = sandbox_id
                    logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
                    return sandbox_id

        # ── Layer 2: Backend discovery + create (protected by cross-process lock) ──
        # Use a file lock so that two processes racing to create the same sandbox
        # for the same thread_id serialize here: the second process will discover
        # the container started by the first instead of hitting a name-conflict.
        if thread_id:
            return self._discover_or_create_with_lock(thread_id, sandbox_id)

        return self._create_sandbox(thread_id, sandbox_id)

    def _discover_or_create_with_lock(self, thread_id: str, sandbox_id: str) -> str:
        """Discover an existing sandbox or create a new one under a cross-process file lock.

        The file lock serializes concurrent sandbox creation for the same thread_id
        across multiple processes, preventing container-name conflicts.
        """
        paths = get_paths()
        paths.ensure_thread_dirs(thread_id)
        lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"

        with open(lock_path, "a") as lock_file:
            try:
                fcntl.flock(lock_file, fcntl.LOCK_EX)
                # Re-check in-process caches under the file lock in case another
                # thread in this process won the race while we were waiting.
                with self._lock:
                    if thread_id in self._thread_sandboxes:
                        existing_id = self._thread_sandboxes[thread_id]
                        if existing_id in self._sandboxes:
                            logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id} (post-lock check)")
                            self._last_activity[existing_id] = time.time()
                            return existing_id
                    if sandbox_id in self._warm_pool:
                        info, _ = self._warm_pool.pop(sandbox_id)
                        sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
                        self._sandboxes[sandbox_id] = sandbox
                        self._sandbox_infos[sandbox_id] = info
                        self._last_activity[sandbox_id] = time.time()
                        self._thread_sandboxes[thread_id] = sandbox_id
                        logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} (post-lock check)")
                        return sandbox_id

                # Backend discovery: another process may have created the container.
                discovered = self._backend.discover(sandbox_id)
                if discovered is not None:
                    sandbox = AioSandbox(id=discovered.sandbox_id, base_url=discovered.sandbox_url)
                    with self._lock:
                        self._sandboxes[discovered.sandbox_id] = sandbox
                        self._sandbox_infos[discovered.sandbox_id] = discovered
                        self._last_activity[discovered.sandbox_id] = time.time()
                        self._thread_sandboxes[thread_id] = discovered.sandbox_id
                    logger.info(f"Discovered existing sandbox {discovered.sandbox_id} for thread {thread_id} at {discovered.sandbox_url}")
                    return discovered.sandbox_id

                return self._create_sandbox(thread_id, sandbox_id)
            finally:
                fcntl.flock(lock_file, fcntl.LOCK_UN)

    def _evict_oldest_warm(self) -> str | None:
        """Destroy the oldest container in the warm pool to free capacity.

        Returns:
            The evicted sandbox_id, or None if warm pool is empty.
        """
        with self._lock:
            if not self._warm_pool:
                return None
            oldest_id = min(self._warm_pool, key=lambda sid: self._warm_pool[sid][1])
            info, _ = self._warm_pool.pop(oldest_id)

        try:
            self._backend.destroy(info)
            logger.info(f"Destroyed warm-pool sandbox {oldest_id}")
        except Exception as e:
            logger.error(f"Failed to destroy warm-pool sandbox {oldest_id}: {e}")
            return None
        return oldest_id

    def _create_sandbox(self, thread_id: str | None, sandbox_id: str) -> str:
        """Create a new sandbox via the backend.

        Args:
            thread_id: Optional thread ID.
            sandbox_id: The sandbox ID to use.

        Returns:
            The sandbox_id.

        Raises:
            RuntimeError: If sandbox creation or readiness check fails.
        """
        extra_mounts = self._get_extra_mounts(thread_id)

        # Enforce replicas: only warm-pool containers count toward eviction budget.
        # Active sandboxes are in use by live threads and must not be forcibly stopped.
        replicas = self._config.get("replicas", DEFAULT_REPLICAS)
        with self._lock:
            total = len(self._sandboxes) + len(self._warm_pool)
        if total >= replicas:
            evicted = self._evict_oldest_warm()
            if evicted:
                logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}")
            else:
                # All slots are occupied by active sandboxes — proceed anyway and log.
                # The replicas limit is a soft cap; we never forcibly stop a container
                # that is actively serving a thread.
                logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")

        info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None)

        # Wait for sandbox to be ready
        if not wait_for_sandbox_ready(info.sandbox_url, timeout=60):
            self._backend.destroy(info)
            raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}")

        sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
        with self._lock:
            self._sandboxes[sandbox_id] = sandbox
            self._sandbox_infos[sandbox_id] = info
            self._last_activity[sandbox_id] = time.time()
            if thread_id:
                self._thread_sandboxes[thread_id] = sandbox_id

        logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
        return sandbox_id

    def get(self, sandbox_id: str) -> Sandbox | None:
        """Get a sandbox by ID. Updates last activity timestamp.

        Args:
            sandbox_id: The ID of the sandbox.

        Returns:
            The sandbox instance if found, None otherwise.
        """
        with self._lock:
            sandbox = self._sandboxes.get(sandbox_id)
            if sandbox is not None:
                self._last_activity[sandbox_id] = time.time()
            return sandbox

    def release(self, sandbox_id: str) -> None:
        """Release a sandbox from active use into the warm pool.

        The container is kept running so it can be reclaimed quickly by the same
        thread on its next turn without a cold-start. The container will only be
        stopped when the replicas limit forces eviction or during shutdown.

        Args:
            sandbox_id: The ID of the sandbox to release.
        """
        info = None
        thread_ids_to_remove: list[str] = []

        with self._lock:
            self._sandboxes.pop(sandbox_id, None)
            info = self._sandbox_infos.pop(sandbox_id, None)
            thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
            for tid in thread_ids_to_remove:
                del self._thread_sandboxes[tid]
            self._last_activity.pop(sandbox_id, None)
            # Park in warm pool — container keeps running
            if info and sandbox_id not in self._warm_pool:
                self._warm_pool[sandbox_id] = (info, time.time())

        logger.info(f"Released sandbox {sandbox_id} to warm pool (container still running)")

    def destroy(self, sandbox_id: str) -> None:
        """Destroy a sandbox: stop the container and free all resources.

        Unlike release(), this actually stops the container. Use this for
        explicit cleanup, capacity-driven eviction, or shutdown.

        Args:
            sandbox_id: The ID of the sandbox to destroy.
        """
        info = None
        thread_ids_to_remove: list[str] = []

        with self._lock:
            self._sandboxes.pop(sandbox_id, None)
            info = self._sandbox_infos.pop(sandbox_id, None)
            thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
            for tid in thread_ids_to_remove:
                del self._thread_sandboxes[tid]
            self._last_activity.pop(sandbox_id, None)
            # Also pull from warm pool if it was parked there
            if info is None and sandbox_id in self._warm_pool:
                info, _ = self._warm_pool.pop(sandbox_id)
            else:
                self._warm_pool.pop(sandbox_id, None)

        if info:
            self._backend.destroy(info)
            logger.info(f"Destroyed sandbox {sandbox_id}")

    def shutdown(self) -> None:
        """Shutdown all sandboxes. Thread-safe and idempotent."""
        with self._lock:
            if self._shutdown_called:
                return
            self._shutdown_called = True
            sandbox_ids = list(self._sandboxes.keys())
            warm_items = list(self._warm_pool.items())
            self._warm_pool.clear()

        # Stop idle checker
        self._idle_checker_stop.set()
        if self._idle_checker_thread is not None and self._idle_checker_thread.is_alive():
            self._idle_checker_thread.join(timeout=5)
            logger.info("Stopped idle checker thread")

        logger.info(f"Shutting down {len(sandbox_ids)} active + {len(warm_items)} warm-pool sandbox(es)")

        for sandbox_id in sandbox_ids:
            try:
                self.destroy(sandbox_id)
            except Exception as e:
                logger.error(f"Failed to destroy sandbox {sandbox_id} during shutdown: {e}")

        for sandbox_id, (info, _) in warm_items:
            try:
                self._backend.destroy(info)
                logger.info(f"Destroyed warm-pool sandbox {sandbox_id} during shutdown")
            except Exception as e:
                logger.error(f"Failed to destroy warm-pool sandbox {sandbox_id} during shutdown: {e}")
@@ -0,0 +1,98 @@
"""Abstract base class for sandbox provisioning backends."""

from __future__ import annotations

import logging
import time
from abc import ABC, abstractmethod

import requests

from .sandbox_info import SandboxInfo

logger = logging.getLogger(__name__)


def wait_for_sandbox_ready(sandbox_url: str, timeout: int = 30) -> bool:
    """Poll sandbox health endpoint until ready or timeout.

    Args:
        sandbox_url: URL of the sandbox (e.g. http://k3s:30001).
        timeout: Maximum time to wait in seconds.

    Returns:
        True if sandbox is ready, False otherwise.
    """
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            response = requests.get(f"{sandbox_url}/v1/sandbox", timeout=5)
            if response.status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass
        time.sleep(1)
    return False
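
# Usage sketch (editor's illustration): gating work on readiness. The URL is
# hypothetical.
#
#     if not wait_for_sandbox_ready("http://localhost:8080", timeout=30):
#         raise RuntimeError("sandbox never became healthy")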


class SandboxBackend(ABC):
    """Abstract base for sandbox provisioning backends.

    Two implementations:
    - LocalContainerBackend: starts Docker/Apple Container locally, manages ports
    - RemoteSandboxBackend: connects to a pre-existing URL (K8s service, external)
    """

    @abstractmethod
    def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
        """Create/provision a new sandbox.

        Args:
            thread_id: Thread ID for which the sandbox is being created. Useful
                for backends that want to organize sandboxes by thread.
            sandbox_id: Deterministic sandbox identifier.
            extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
                Ignored by backends that don't manage containers (e.g., remote).

        Returns:
            SandboxInfo with connection details.
        """
        ...

    @abstractmethod
    def destroy(self, info: SandboxInfo) -> None:
        """Destroy/cleanup a sandbox and release its resources.

        Args:
            info: The sandbox metadata to destroy.
        """
        ...

    @abstractmethod
    def is_alive(self, info: SandboxInfo) -> bool:
        """Quick check whether a sandbox is still alive.

        This should be a lightweight check (e.g., container inspect)
        rather than a full health check.

        Args:
            info: The sandbox metadata to check.

        Returns:
            True if the sandbox appears to be alive.
        """
        ...

    @abstractmethod
    def discover(self, sandbox_id: str) -> SandboxInfo | None:
        """Try to discover an existing sandbox by its deterministic ID.

        Used for cross-process recovery: when another process started a sandbox,
        this process can discover it by the deterministic container name or URL.

        Args:
            sandbox_id: The deterministic sandbox ID to look for.

        Returns:
            SandboxInfo if found and healthy, None otherwise.
        """
        ...
@@ -0,0 +1,327 @@
"""Local container backend for sandbox provisioning.

Manages sandbox containers using Docker or Apple Container on the local machine.
Handles container lifecycle, port allocation, and cross-process container discovery.
"""

from __future__ import annotations

import logging
import os
import subprocess

from deerflow.utils.network import get_free_port, release_port

from .backend import SandboxBackend, wait_for_sandbox_ready
from .sandbox_info import SandboxInfo

logger = logging.getLogger(__name__)


class LocalContainerBackend(SandboxBackend):
    """Backend that manages sandbox containers locally using Docker or Apple Container.

    On macOS, automatically prefers Apple Container if available, otherwise falls back to Docker.
    On other platforms, uses Docker.

    Features:
    - Deterministic container naming for cross-process discovery
    - Port allocation with thread-safe utilities
    - Container lifecycle management (start/stop with --rm)
    - Support for volume mounts and environment variables
    """

    def __init__(
        self,
        *,
        image: str,
        base_port: int,
        container_prefix: str,
        config_mounts: list,
        environment: dict[str, str],
    ):
        """Initialize the local container backend.

        Args:
            image: Container image to use.
            base_port: Base port number to start searching for free ports.
            container_prefix: Prefix for container names (e.g., "deer-flow-sandbox").
            config_mounts: Volume mount configurations from config (list of VolumeMountConfig).
            environment: Environment variables to inject into containers.
        """
        self._image = image
        self._base_port = base_port
        self._container_prefix = container_prefix
        self._config_mounts = config_mounts
        self._environment = environment
        self._runtime = self._detect_runtime()

    @property
    def runtime(self) -> str:
        """The detected container runtime ("docker" or "container")."""
        return self._runtime

    def _detect_runtime(self) -> str:
        """Detect which container runtime to use.

        On macOS, prefer Apple Container if available, otherwise fall back to Docker.
        On other platforms, use Docker.

        Returns:
            "container" for Apple Container, "docker" for Docker.
        """
        import platform

        if platform.system() == "Darwin":
            try:
                result = subprocess.run(
                    ["container", "--version"],
                    capture_output=True,
                    text=True,
                    check=True,
                    timeout=5,
                )
                logger.info(f"Detected Apple Container: {result.stdout.strip()}")
                return "container"
            except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
                logger.info("Apple Container not available, falling back to Docker")

        return "docker"

    # ── SandboxBackend interface ──────────────────────────────────────────

    def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
        """Start a new container and return its connection info.

        Args:
            thread_id: Thread ID for which the sandbox is being created. Useful
                for backends that want to organize sandboxes by thread.
            sandbox_id: Deterministic sandbox identifier (used in container name).
            extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.

        Returns:
            SandboxInfo with container details.

        Raises:
            RuntimeError: If the container fails to start.
        """
        container_name = f"{self._container_prefix}-{sandbox_id}"

        # Retry loop: if Docker rejects the port (e.g. a stale container still
        # holds the binding after a process restart), skip that port and try the
        # next one. The socket-bind check in get_free_port mirrors Docker's
        # 0.0.0.0 bind, but Docker's port-release can be slightly asynchronous,
        # so a reactive fallback here ensures we always make progress.
        _next_start = self._base_port
        container_id: str | None = None
        port: int = 0
        for _attempt in range(10):
            port = get_free_port(start_port=_next_start)
            try:
                container_id = self._start_container(container_name, port, extra_mounts)
                break
            except RuntimeError as exc:
                release_port(port)
                err = str(exc)
                err_lower = err.lower()
                # Port already bound: skip this port and retry with the next one.
                if "port is already allocated" in err_lower or "address already in use" in err_lower:
                    logger.warning(f"Port {port} rejected by Docker (already allocated), retrying with next port")
                    _next_start = port + 1
                    continue
                # Container-name conflict: another process may have already started
                # the deterministic sandbox container for this sandbox_id. Try to
                # discover and adopt the existing container instead of failing.
                if "is already in use by container" in err_lower or "conflict. the container name" in err_lower:
                    logger.warning(f"Container name {container_name} already in use, attempting to discover existing sandbox instance")
                    existing = self.discover(sandbox_id)
                    if existing is not None:
                        return existing
                raise
        else:
            raise RuntimeError("Could not start sandbox container: all candidate ports are already allocated by Docker")

        # When running inside Docker (DooD), sandbox containers are reachable via
        # host.docker.internal rather than localhost (they run on the host daemon).
        sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
        return SandboxInfo(
            sandbox_id=sandbox_id,
            sandbox_url=f"http://{sandbox_host}:{port}",
            container_name=container_name,
            container_id=container_id,
        )

    def destroy(self, info: SandboxInfo) -> None:
        """Stop the container and release its port."""
        if info.container_id:
            self._stop_container(info.container_id)
        # Extract port from sandbox_url for release
        try:
            from urllib.parse import urlparse

            port = urlparse(info.sandbox_url).port
            if port:
                release_port(port)
        except Exception:
            pass

    def is_alive(self, info: SandboxInfo) -> bool:
        """Check if the container is still running (lightweight, no HTTP)."""
        if info.container_name:
            return self._is_container_running(info.container_name)
        return False

    def discover(self, sandbox_id: str) -> SandboxInfo | None:
        """Discover an existing container by its deterministic name.

        Checks if a container with the expected name is running, retrieves its
        port, and verifies it responds to health checks.

        Args:
            sandbox_id: The deterministic sandbox ID (determines container name).

        Returns:
            SandboxInfo if container found and healthy, None otherwise.
        """
        container_name = f"{self._container_prefix}-{sandbox_id}"

        if not self._is_container_running(container_name):
            return None

        port = self._get_container_port(container_name)
        if port is None:
            return None

        sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
        sandbox_url = f"http://{sandbox_host}:{port}"
        if not wait_for_sandbox_ready(sandbox_url, timeout=5):
            return None

        return SandboxInfo(
            sandbox_id=sandbox_id,
            sandbox_url=sandbox_url,
            container_name=container_name,
        )
|
||||
|
||||
# ── Container operations ─────────────────────────────────────────────
|
||||
|
||||
def _start_container(
|
||||
self,
|
||||
container_name: str,
|
||||
port: int,
|
||||
extra_mounts: list[tuple[str, str, bool]] | None = None,
|
||||
) -> str:
|
||||
"""Start a new container.
|
||||
|
||||
Args:
|
||||
container_name: Name for the container.
|
||||
port: Host port to map to container port 8080.
|
||||
extra_mounts: Additional volume mounts.
|
||||
|
||||
Returns:
|
||||
The container ID.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If container fails to start.
|
||||
"""
|
||||
cmd = [self._runtime, "run"]
|
||||
|
||||
# Docker-specific security options
|
||||
if self._runtime == "docker":
|
||||
cmd.extend(["--security-opt", "seccomp=unconfined"])
|
||||
|
||||
cmd.extend(
|
||||
[
|
||||
"--rm",
|
||||
"-d",
|
||||
"-p",
|
||||
f"{port}:8080",
|
||||
"--name",
|
||||
container_name,
|
||||
]
|
||||
)
|
||||
|
||||
# Environment variables
|
||||
for key, value in self._environment.items():
|
||||
cmd.extend(["-e", f"{key}={value}"])
|
||||
|
||||
# Config-level volume mounts
|
||||
for mount in self._config_mounts:
|
||||
mount_spec = f"{mount.host_path}:{mount.container_path}"
|
||||
if mount.read_only:
|
||||
mount_spec += ":ro"
|
||||
cmd.extend(["-v", mount_spec])
|
||||
|
||||
# Extra mounts (thread-specific, skills, etc.)
|
||||
if extra_mounts:
|
||||
for host_path, container_path, read_only in extra_mounts:
|
||||
mount_spec = f"{host_path}:{container_path}"
|
||||
if read_only:
|
||||
mount_spec += ":ro"
|
||||
cmd.extend(["-v", mount_spec])
|
||||
|
||||
cmd.append(self._image)
|
||||
|
||||
logger.info(f"Starting container using {self._runtime}: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
container_id = result.stdout.strip()
|
||||
logger.info(f"Started container {container_name} (ID: {container_id}) using {self._runtime}")
|
||||
return container_id
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Failed to start container using {self._runtime}: {e.stderr}")
|
||||
raise RuntimeError(f"Failed to start sandbox container: {e.stderr}")
|
||||
|
||||
def _stop_container(self, container_id: str) -> None:
|
||||
"""Stop a container (--rm ensures automatic removal)."""
|
||||
try:
|
||||
subprocess.run(
|
||||
[self._runtime, "stop", container_id],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
logger.info(f"Stopped container {container_id} using {self._runtime}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(f"Failed to stop container {container_id}: {e.stderr}")
|
||||
|
||||
def _is_container_running(self, container_name: str) -> bool:
|
||||
"""Check if a named container is currently running.
|
||||
|
||||
This enables cross-process container discovery — any process can detect
|
||||
containers started by another process via the deterministic container name.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[self._runtime, "inspect", "-f", "{{.State.Running}}", container_name],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.returncode == 0 and result.stdout.strip().lower() == "true"
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
|
||||
return False
|
||||
|
||||
def _get_container_port(self, container_name: str) -> int | None:
|
||||
"""Get the host port of a running container.
|
||||
|
||||
Args:
|
||||
container_name: The container name to inspect.
|
||||
|
||||
Returns:
|
||||
The host port mapped to container port 8080, or None if not found.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[self._runtime, "port", container_name, "8080"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
# Output format: "0.0.0.0:PORT" or ":::PORT"
|
||||
port_str = result.stdout.strip().split(":")[-1]
|
||||
return int(port_str)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, ValueError):
|
||||
pass
|
||||
return None
|
||||
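A cross-process usage sketch of the interface above (the `backend` instance stands in for the local container backend this file defines; construction details are earlier in the file):

info = backend.create(thread_id="t-1", sandbox_id="abc123")

# Later, possibly in another process (gateway vs. langgraph worker), the same
# deterministic sandbox_id re-attaches to the already-running container:
existing = backend.discover("abc123")
if existing is None:
    existing = backend.create(thread_id="t-1", sandbox_id="abc123")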
@@ -0,0 +1,156 @@
"""Remote sandbox backend — delegates Pod lifecycle to the provisioner service.

The provisioner dynamically creates per-sandbox-id Pods + NodePort Services
in k3s. The backend accesses sandbox pods directly via ``k3s:{NodePort}``.

Architecture:
    ┌────────────┐  HTTP   ┌─────────────┐  K8s API  ┌──────────┐
    │ this file  │ ──────▸ │ provisioner │ ────────▸ │   k3s    │
    │ (backend)  │         │    :8002    │           │  :6443   │
    └────────────┘         └─────────────┘           └─────┬────┘
                                                           │ creates
    ┌─────────────┐           ┌─────▼──────┐
    │   backend   │ ────────▸ │  sandbox   │
    │             │  direct   │   Pod(s)   │
    └─────────────┘ k3s:NPort └────────────┘
"""

from __future__ import annotations

import logging

import requests

from .backend import SandboxBackend
from .sandbox_info import SandboxInfo

logger = logging.getLogger(__name__)


class RemoteSandboxBackend(SandboxBackend):
    """Backend that delegates sandbox lifecycle to the provisioner service.

    All Pod creation, destruction, and discovery are handled by the
    provisioner. This backend is a thin HTTP client.

    Typical config.yaml::

        sandbox:
          use: deerflow.community.aio_sandbox:AioSandboxProvider
          provisioner_url: http://provisioner:8002
    """

    def __init__(self, provisioner_url: str):
        """Initialize with the provisioner service URL.

        Args:
            provisioner_url: URL of the provisioner service
                (e.g., ``http://provisioner:8002``).
        """
        self._provisioner_url = provisioner_url.rstrip("/")

    @property
    def provisioner_url(self) -> str:
        return self._provisioner_url

    # ── SandboxBackend interface ──────────────────────────────────────────

    def create(
        self,
        thread_id: str,
        sandbox_id: str,
        extra_mounts: list[tuple[str, str, bool]] | None = None,
    ) -> SandboxInfo:
        """Create a sandbox Pod + Service via the provisioner.

        Calls ``POST /api/sandboxes`` which creates a dedicated Pod +
        NodePort Service in k3s.
        """
        return self._provisioner_create(thread_id, sandbox_id, extra_mounts)

    def destroy(self, info: SandboxInfo) -> None:
        """Destroy a sandbox Pod + Service via the provisioner."""
        self._provisioner_destroy(info.sandbox_id)

    def is_alive(self, info: SandboxInfo) -> bool:
        """Check whether the sandbox Pod is running."""
        return self._provisioner_is_alive(info.sandbox_id)

    def discover(self, sandbox_id: str) -> SandboxInfo | None:
        """Discover an existing sandbox via the provisioner.

        Calls ``GET /api/sandboxes/{sandbox_id}`` and returns info if
        the Pod exists.
        """
        return self._provisioner_discover(sandbox_id)

    # ── Provisioner API calls ─────────────────────────────────────────────

    def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
        """POST /api/sandboxes → create Pod + Service."""
        try:
            resp = requests.post(
                f"{self._provisioner_url}/api/sandboxes",
                json={
                    "sandbox_id": sandbox_id,
                    "thread_id": thread_id,
                },
                timeout=30,
            )
            resp.raise_for_status()
            data = resp.json()
            logger.info(f"Provisioner created sandbox {sandbox_id}: sandbox_url={data['sandbox_url']}")
            return SandboxInfo(
                sandbox_id=sandbox_id,
                sandbox_url=data["sandbox_url"],
            )
        except requests.RequestException as exc:
            logger.error(f"Provisioner create failed for {sandbox_id}: {exc}")
            raise RuntimeError(f"Provisioner create failed: {exc}") from exc

    def _provisioner_destroy(self, sandbox_id: str) -> None:
        """DELETE /api/sandboxes/{sandbox_id} → destroy Pod + Service."""
        try:
            resp = requests.delete(
                f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
                timeout=15,
            )
            if resp.ok:
                logger.info(f"Provisioner destroyed sandbox {sandbox_id}")
            else:
                logger.warning(f"Provisioner destroy returned {resp.status_code}: {resp.text}")
        except requests.RequestException as exc:
            logger.warning(f"Provisioner destroy failed for {sandbox_id}: {exc}")

    def _provisioner_is_alive(self, sandbox_id: str) -> bool:
        """GET /api/sandboxes/{sandbox_id} → check Pod phase."""
        try:
            resp = requests.get(
                f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
                timeout=10,
            )
            if resp.ok:
                data = resp.json()
                return data.get("status") == "Running"
            return False
        except requests.RequestException:
            return False

    def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
        """GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
        try:
            resp = requests.get(
                f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
                timeout=10,
            )
            if resp.status_code == 404:
                return None
            resp.raise_for_status()
            data = resp.json()
            return SandboxInfo(
                sandbox_id=sandbox_id,
                sandbox_url=data["sandbox_url"],
            )
        except requests.RequestException as exc:
            logger.debug(f"Provisioner discover failed for {sandbox_id}: {exc}")
            return None
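A round-trip sketch of the provisioner contract this client assumes; only `sandbox_url` and `status` are read from responses, so the rest of the schema is not pinned down by this diff:

# POST /api/sandboxes          {"sandbox_id": "abc123", "thread_id": "t-1"}
#   → 200 {"sandbox_url": "http://k3s:30001", ...}
# GET  /api/sandboxes/abc123
#   → 200 {"sandbox_url": "http://k3s:30001", "status": "Running"}, or 404
backend = RemoteSandboxBackend("http://provisioner:8002")
info = backend.create(thread_id="t-1", sandbox_id="abc123")
print(info.sandbox_url)   # e.g. http://k3s:30001
backend.destroy(info)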
@@ -0,0 +1,41 @@
"""Sandbox metadata for cross-process discovery and state persistence."""

from __future__ import annotations

import time
from dataclasses import dataclass, field


@dataclass
class SandboxInfo:
    """Persisted sandbox metadata that enables cross-process discovery.

    This dataclass holds all the information needed to reconnect to an
    existing sandbox from a different process (e.g., gateway vs langgraph,
    multiple workers, or across K8s pods with shared storage).
    """

    sandbox_id: str
    sandbox_url: str  # e.g. http://localhost:8080 or http://k3s:30001
    container_name: str | None = None  # Only for local container backend
    container_id: str | None = None  # Only for local container backend
    created_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        return {
            "sandbox_id": self.sandbox_id,
            "sandbox_url": self.sandbox_url,
            "container_name": self.container_name,
            "container_id": self.container_id,
            "created_at": self.created_at,
        }

    @classmethod
    def from_dict(cls, data: dict) -> SandboxInfo:
        return cls(
            sandbox_id=data["sandbox_id"],
            sandbox_url=data.get("sandbox_url", data.get("base_url", "")),
            container_name=data.get("container_name"),
            container_id=data.get("container_id"),
            created_at=data.get("created_at", time.time()),
        )
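A persistence round trip, sketched under the assumption that callers serialize `to_dict()` as JSON to shared storage (the JSON step and the storage location are illustrative, not part of this diff):

import json

info = SandboxInfo(sandbox_id="abc123", sandbox_url="http://localhost:8080")
blob = json.dumps(info.to_dict())                    # process A writes to shared storage
restored = SandboxInfo.from_dict(json.loads(blob))   # process B reconnects later
assert restored.sandbox_url == info.sandbox_url
# from_dict also accepts legacy payloads that used "base_url" instead of "sandbox_url".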
@@ -0,0 +1,73 @@
import json

from firecrawl import FirecrawlApp
from langchain.tools import tool

from deerflow.config import get_app_config


def _get_firecrawl_client() -> FirecrawlApp:
    config = get_app_config().get_tool_config("web_search")
    api_key = None
    if config is not None:
        api_key = config.model_extra.get("api_key")
    return FirecrawlApp(api_key=api_key)  # type: ignore[arg-type]


@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
    """Search the web.

    Args:
        query: The query to search for.
    """
    try:
        config = get_app_config().get_tool_config("web_search")
        max_results = 5
        if config is not None:
            max_results = config.model_extra.get("max_results", max_results)

        client = _get_firecrawl_client()
        result = client.search(query, limit=max_results)

        # result.web contains a list of SearchResultWeb objects
        web_results = result.web or []
        normalized_results = [
            {
                "title": getattr(item, "title", "") or "",
                "url": getattr(item, "url", "") or "",
                "snippet": getattr(item, "description", "") or "",
            }
            for item in web_results
        ]
        json_results = json.dumps(normalized_results, indent=2, ensure_ascii=False)
        return json_results
    except Exception as e:
        return f"Error: {str(e)}"


@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
    """Fetch the contents of a web page at a given URL.
    Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
    This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
    Do NOT add www. to URLs that do NOT have them.
    URLs must include the scheme: https://example.com is a valid URL while example.com is an invalid URL.

    Args:
        url: The URL to fetch the contents of.
    """
    try:
        client = _get_firecrawl_client()
        result = client.scrape(url, formats=["markdown"])

        markdown_content = result.markdown or ""
        metadata = result.metadata
        title = metadata.title if metadata and metadata.title else "Untitled"

        if not markdown_content:
            return "Error: No content found"
    except Exception as e:
        return f"Error: {str(e)}"

    return f"# {title}\n\n{markdown_content[:4096]}"
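For reference, a config.yaml stanza these helpers would read; the key names follow the `model_extra` lookups above, while the overall `tools:` layout is defined by ToolConfig elsewhere in this PR (values illustrative):

tools:
  - name: web_search
    api_key: $FIRECRAWL_API_KEY   # resolved from the environment at config load
    max_results: 5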
@@ -0,0 +1,3 @@
from .tools import image_search_tool

__all__ = ["image_search_tool"]
@@ -0,0 +1,135 @@
"""
Image Search Tool - Search images using DuckDuckGo for reference in image generation.
"""

import json
import logging

from langchain.tools import tool

from deerflow.config import get_app_config

logger = logging.getLogger(__name__)


def _search_images(
    query: str,
    max_results: int = 5,
    region: str = "wt-wt",
    safesearch: str = "moderate",
    size: str | None = None,
    color: str | None = None,
    type_image: str | None = None,
    layout: str | None = None,
    license_image: str | None = None,
) -> list[dict]:
    """
    Execute image search using DuckDuckGo.

    Args:
        query: Search keywords
        max_results: Maximum number of results
        region: Search region
        safesearch: Safe search level
        size: Image size (Small/Medium/Large/Wallpaper)
        color: Color filter
        type_image: Image type (photo/clipart/gif/transparent/line)
        layout: Layout (Square/Tall/Wide)
        license_image: License filter

    Returns:
        List of search results
    """
    try:
        from ddgs import DDGS
    except ImportError:
        logger.error("ddgs library not installed. Run: pip install ddgs")
        return []

    ddgs = DDGS(timeout=30)

    try:
        kwargs = {
            "region": region,
            "safesearch": safesearch,
            "max_results": max_results,
        }

        if size:
            kwargs["size"] = size
        if color:
            kwargs["color"] = color
        if type_image:
            kwargs["type_image"] = type_image
        if layout:
            kwargs["layout"] = layout
        if license_image:
            kwargs["license_image"] = license_image

        results = ddgs.images(query, **kwargs)
        return list(results) if results else []

    except Exception as e:
        logger.error(f"Failed to search images: {e}")
        return []


@tool("image_search", parse_docstring=True)
def image_search_tool(
    query: str,
    max_results: int = 5,
    size: str | None = None,
    type_image: str | None = None,
    layout: str | None = None,
) -> str:
    """Search for images online. Use this tool BEFORE image generation to find reference images for characters, portraits, objects, scenes, or any content requiring visual accuracy.

    **When to use:**
    - Before generating character/portrait images: search for similar poses, expressions, styles
    - Before generating specific objects/products: search for accurate visual references
    - Before generating scenes/locations: search for architectural or environmental references
    - Before generating fashion/clothing: search for style and detail references

    The returned image URLs can be used as reference images in image generation to significantly improve quality.

    Args:
        query: Search keywords describing the images you want to find. Be specific for better results (e.g., "Japanese woman street photography 1990s" instead of just "woman").
        max_results: Maximum number of images to return. Default is 5.
        size: Image size filter. Options: "Small", "Medium", "Large", "Wallpaper". Use "Large" for reference images.
        type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
        layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
    """
    config = get_app_config().get_tool_config("image_search")

    # Override max_results from config if set
    if config is not None and "max_results" in config.model_extra:
        max_results = config.model_extra.get("max_results", max_results)

    results = _search_images(
        query=query,
        max_results=max_results,
        size=size,
        type_image=type_image,
        layout=layout,
    )

    if not results:
        return json.dumps({"error": "No images found", "query": query}, ensure_ascii=False)

    normalized_results = [
        {
            "title": r.get("title", ""),
            "image_url": r.get("thumbnail", ""),
            "thumbnail_url": r.get("thumbnail", ""),
        }
        for r in results
    ]

    output = {
        "query": query,
        "total_results": len(normalized_results),
        "results": normalized_results,
        "usage_hint": "Use the 'image_url' values as reference images in image generation. Download them first if needed.",
    }

    return json.dumps(output, indent=2, ensure_ascii=False)
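A quick invocation sketch; LangChain tools accept a dict of arguments via `.invoke`, and the return value here is the JSON string built above:

out = image_search_tool.invoke({
    "query": "art deco hotel lobby interior",
    "max_results": 3,
    "size": "Large",
    "type_image": "photo",
})
# out is JSON: {"query": ..., "total_results": ..., "results": [...], "usage_hint": ...}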
@@ -0,0 +1,311 @@
"""Util that calls the InfoQuest Search and Fetch API.

In order to set this up, follow the instructions at:
https://docs.byteplus.com/en/docs/InfoQuest/What_is_Info_Quest
"""

import json
import logging
import os
from typing import Any

import requests

logger = logging.getLogger(__name__)


class InfoQuestClient:
    """Client for interacting with the InfoQuest web search and fetch API."""

    def __init__(self, fetch_time: int = -1, fetch_timeout: int = -1, fetch_navigation_timeout: int = -1, search_time_range: int = -1):
        logger.info("\n============================================\n🚀 BytePlus InfoQuest Client Initialization 🚀\n============================================")

        self.fetch_time = fetch_time
        self.fetch_timeout = fetch_timeout
        self.fetch_navigation_timeout = fetch_navigation_timeout
        self.search_time_range = search_time_range
        self.api_key_set = bool(os.getenv("INFOQUEST_API_KEY"))
        if logger.isEnabledFor(logging.DEBUG):
            config_details = (
                f"\n📋 Configuration Details:\n"
                f"├── Fetch time: {fetch_time} {'(Default: No fetch time)' if fetch_time == -1 else '(Custom)'}\n"
                f"├── Fetch Timeout: {fetch_timeout} {'(Default: No fetch timeout)' if fetch_timeout == -1 else '(Custom)'}\n"
                f"├── Navigation Timeout: {fetch_navigation_timeout} {'(Default: No Navigation Timeout)' if fetch_navigation_timeout == -1 else '(Custom)'}\n"
                f"├── Search Time Range: {search_time_range} {'(Default: No Search Time Range)' if search_time_range == -1 else '(Custom)'}\n"
                f"└── API Key: {'✅ Configured' if self.api_key_set else '❌ Not set'}"
            )

            logger.debug(config_details)
            logger.debug("\n" + "*" * 70 + "\n")

    def fetch(self, url: str, return_format: str = "html") -> str:
        if logger.isEnabledFor(logging.DEBUG):
            url_truncated = url[:50] + "..." if len(url) > 50 else url
            logger.debug(
                f"InfoQuest - Fetch API request initiated | "
                f"operation=crawl url | "
                f"url_truncated={url_truncated} | "
                f"has_timeout_filter={self.fetch_timeout > 0} | timeout_filter={self.fetch_timeout} | "
                f"has_fetch_time_filter={self.fetch_time > 0} | fetch_time_filter={self.fetch_time} | "
                f"has_navigation_timeout_filter={self.fetch_navigation_timeout > 0} | navi_timeout_filter={self.fetch_navigation_timeout} | "
                f"request_type=sync"
            )

        # Prepare headers
        headers = self._prepare_headers()

        # Prepare request data
        data = self._prepare_crawl_request_data(url, return_format)

        logger.debug("Sending crawl request to InfoQuest API")
        try:
            response = requests.post("https://reader.infoquest.bytepluses.com", headers=headers, json=data)

            # Check if status code is not 200
            if response.status_code != 200:
                error_message = f"fetch API returned status {response.status_code}: {response.text}"
                logger.debug("InfoQuest Crawler fetch API returned status %d: %s for URL: %s", response.status_code, response.text, url)
                return f"Error: {error_message}"

            # Check for empty response
            if not response.text or not response.text.strip():
                error_message = "no result found"
                logger.debug("InfoQuest Crawler returned empty response for URL: %s", url)
                return f"Error: {error_message}"

            # Try to parse response as JSON and extract reader_result
            try:
                response_data = json.loads(response.text)
                # Extract reader_result if it exists
                if "reader_result" in response_data:
                    logger.debug("Successfully extracted reader_result from JSON response")
                    return response_data["reader_result"]
                elif "content" in response_data:
                    # Fall back to the content field if reader_result is not available
                    logger.debug("reader_result missing in JSON response, falling back to content field: %s", response_data["content"])
                    return response_data["content"]
                else:
                    # If neither field exists, return the original response
                    logger.warning("Neither reader_result nor content field found in JSON response")
            except json.JSONDecodeError:
                # If the response is not JSON, return the original text
                logger.debug("Response is not in JSON format, returning as-is")
                return response.text

            # Print a partial response for debugging
            if logger.isEnabledFor(logging.DEBUG):
                response_sample = response.text[:200] + ("..." if len(response.text) > 200 else "")
                logger.debug("Successfully received response, content length: %d bytes, first 200 chars: %s", len(response.text), response_sample)
            return response.text
        except Exception as e:
            error_message = f"fetch API failed: {str(e)}"
            logger.error(error_message)
            return f"Error: {error_message}"

    @staticmethod
    def _prepare_headers() -> dict[str, str]:
        """Prepare request headers."""
        headers = {
            "Content-Type": "application/json",
        }

        # Add API key if available
        if os.getenv("INFOQUEST_API_KEY"):
            headers["Authorization"] = f"Bearer {os.getenv('INFOQUEST_API_KEY')}"
            logger.debug("API key added to request headers")
        else:
            logger.warning("InfoQuest API key is not set. Provide your own key for authentication.")

        return headers

    def _prepare_crawl_request_data(self, url: str, return_format: str) -> dict[str, Any]:
        """Prepare request data with formatted parameters."""
        # Normalize return_format
        if return_format and return_format.lower() == "html":
            normalized_format = "HTML"
        else:
            normalized_format = return_format

        data = {"url": url, "format": normalized_format}

        # Add timeout parameters if set to positive values
        timeout_params = {}
        if self.fetch_time > 0:
            timeout_params["fetch_time"] = self.fetch_time
        if self.fetch_timeout > 0:
            timeout_params["timeout"] = self.fetch_timeout
        if self.fetch_navigation_timeout > 0:
            timeout_params["navi_timeout"] = self.fetch_navigation_timeout

        # Log applied timeout parameters
        if timeout_params:
            logger.debug("Applying timeout parameters: %s", timeout_params)
            data.update(timeout_params)

        return data

    def web_search_raw_results(
        self,
        query: str,
        site: str,
        output_format: str = "JSON",
    ) -> dict:
        """Get results from the InfoQuest Web-Search API synchronously."""
        headers = self._prepare_headers()

        params = {"format": output_format, "query": query}
        if self.search_time_range > 0:
            params["time_range"] = self.search_time_range

        if site != "":
            params["site"] = site

        response = requests.post("https://search.infoquest.bytepluses.com", headers=headers, json=params)
        response.raise_for_status()

        # Print a partial response for debugging
        response_json = response.json()
        if logger.isEnabledFor(logging.DEBUG):
            response_sample = json.dumps(response_json)[:200] + ("..." if len(json.dumps(response_json)) > 200 else "")
            logger.debug(f"Search API request completed successfully | service=InfoQuest | status=success | response_sample={response_sample}")

        return response_json

    @staticmethod
    def clean_results(raw_results: list[dict[str, dict[str, dict[str, Any]]]]) -> list[dict]:
        """Clean results from the InfoQuest Web-Search API."""
        logger.debug("Processing web-search results")

        seen_urls = set()
        clean_results = []
        counts = {"pages": 0, "news": 0}

        for content_list in raw_results:
            content = content_list["content"]
            results = content["results"]

            if results.get("organic"):
                organic_results = results["organic"]
                for result in organic_results:
                    clean_result = {
                        "type": "page",
                    }
                    if "title" in result:
                        clean_result["title"] = result["title"]
                    if "desc" in result:
                        clean_result["desc"] = result["desc"]
                        clean_result["snippet"] = result["desc"]
                    if "url" in result:
                        clean_result["url"] = result["url"]
                        url = clean_result["url"]
                        if isinstance(url, str) and url and url not in seen_urls:
                            seen_urls.add(url)
                            clean_results.append(clean_result)
                            counts["pages"] += 1

            if results.get("top_stories"):
                news = results["top_stories"]
                for obj in news["items"]:
                    clean_result = {
                        "type": "news",
                    }
                    if "time_frame" in obj:
                        clean_result["time_frame"] = obj["time_frame"]
                    if "source" in obj:
                        clean_result["source"] = obj["source"]
                    title = obj.get("title")
                    url = obj.get("url")
                    if title:
                        clean_result["title"] = title
                    if url:
                        clean_result["url"] = url
                    if title and isinstance(url, str) and url and url not in seen_urls:
                        seen_urls.add(url)
                        clean_results.append(clean_result)
                        counts["news"] += 1
        logger.debug(f"Results processing completed | total_results={len(clean_results)} | pages={counts['pages']} | news_items={counts['news']} | unique_urls={len(seen_urls)}")

        return clean_results

    def web_search(
        self,
        query: str,
        site: str = "",
        output_format: str = "JSON",
    ) -> str:
        if logger.isEnabledFor(logging.DEBUG):
            query_truncated = query[:50] + "..." if len(query) > 50 else query
            logger.debug(
                f"InfoQuest - Search API request initiated | "
                f"operation=search webs | "
                f"query_truncated={query_truncated} | "
                f"has_time_filter={self.search_time_range > 0} | time_filter={self.search_time_range} | "
                f"has_site_filter={bool(site)} | site={site} | "
                f"request_type=sync"
            )

        try:
            logger.debug("InfoQuest Web-Search - Executing search with parameters")
            raw_results = self.web_search_raw_results(
                query,
                site,
                output_format,
            )
            if "search_result" in raw_results:
                logger.debug("InfoQuest Web-Search - Successfully extracted search_result from JSON response")
                results = raw_results["search_result"]

                logger.debug("InfoQuest Web-Search - Processing raw search results")
                cleaned_results = self.clean_results(results["results"])

                result_json = json.dumps(cleaned_results, indent=2, ensure_ascii=False)

                logger.debug(f"InfoQuest Web-Search - Search tool execution completed | mode=synchronous | results_count={len(cleaned_results)}")
                return result_json

            elif "content" in raw_results:
                # A content field without search_result indicates an unexpected response shape.
                error_message = "web search API returned wrong format"
                logger.error("web search API returned wrong format, no search_result field found in JSON response, content: %s", raw_results["content"])
                return f"Error: {error_message}"
            else:
                # If neither field exists, return the original response
                logger.warning("InfoQuest Web-Search - Neither search_result nor content field found in JSON response")
                return json.dumps(raw_results, indent=2, ensure_ascii=False)

        except Exception as e:
            error_message = f"InfoQuest Web-Search - Search tool execution failed | mode=synchronous | error={str(e)}"
            logger.error(error_message)
            return f"Error: {error_message}"

    @staticmethod
    def clean_results_with_image_search(raw_results: list[dict[str, dict[str, dict[str, Any]]]]) -> list[dict]:
        """Clean image results from the InfoQuest Web-Search API."""
        logger.debug("Processing web-search results")

        seen_urls = set()
        clean_results = []
        counts = {"images": 0}

        for content_list in raw_results:
            content = content_list["content"]
            results = content["results"]

            if results.get("images_results"):
                images_results = results["images_results"]
                for result in images_results:
                    clean_result = {}
                    if "image_url" in result:
                        clean_result["image_url"] = result["image_url"]
                        url = clean_result["image_url"]
                        if isinstance(url, str) and url and url not in seen_urls:
                            seen_urls.add(url)
                            clean_results.append(clean_result)
                            counts["images"] += 1
                    if "thumbnail_url" in result:
                        clean_result["thumbnail_url"] = result["thumbnail_url"]
                    if "url" in result:
                        clean_result["url"] = result["url"]
        logger.debug(f"Results processing completed | total_results={len(clean_results)} | images={counts['images']} | unique_urls={len(seen_urls)}")

        return clean_results
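Inferred from the parsing above, the raw search payload this client expects looks roughly like the following; field names are taken from the code, and the full schema is documented by BytePlus, not this diff:

# Hypothetical minimal payload accepted by InfoQuestClient.clean_results:
raw = [
    {
        "content": {
            "results": {
                "organic": [{"title": "t", "desc": "d", "url": "https://example.com"}],
                "top_stories": {"items": [{"title": "n", "url": "https://example.com/news"}]},
            }
        }
    }
]
cleaned = InfoQuestClient.clean_results(raw)
# → [{"type": "page", "title": "t", "desc": "d", "snippet": "d", "url": ...},
#    {"type": "news", "title": "n", "url": ...}]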
@@ -0,0 +1,63 @@
from langchain.tools import tool

from deerflow.config import get_app_config
from deerflow.utils.readability import ReadabilityExtractor

from .infoquest_client import InfoQuestClient

readability_extractor = ReadabilityExtractor()


def _get_infoquest_client() -> InfoQuestClient:
    search_config = get_app_config().get_tool_config("web_search")
    search_time_range = -1
    if search_config is not None and "search_time_range" in search_config.model_extra:
        search_time_range = search_config.model_extra.get("search_time_range")
    fetch_config = get_app_config().get_tool_config("web_fetch")
    fetch_time = -1
    if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
        fetch_time = fetch_config.model_extra.get("fetch_time")
    fetch_timeout = -1
    if fetch_config is not None and "timeout" in fetch_config.model_extra:
        fetch_timeout = fetch_config.model_extra.get("timeout")
    navigation_timeout = -1
    if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
        navigation_timeout = fetch_config.model_extra.get("navigation_timeout")

    return InfoQuestClient(
        search_time_range=search_time_range,
        fetch_timeout=fetch_timeout,
        fetch_navigation_timeout=navigation_timeout,
        fetch_time=fetch_time,
    )


@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
    """Search the web.

    Args:
        query: The query to search for.
    """

    client = _get_infoquest_client()
    return client.web_search(query)


@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
    """Fetch the contents of a web page at a given URL.
    Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
    This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
    Do NOT add www. to URLs that do NOT have them.
    URLs must include the scheme: https://example.com is a valid URL while example.com is an invalid URL.

    Args:
        url: The URL to fetch the contents of.
    """
    client = _get_infoquest_client()
    result = client.fetch(url)
    if result.startswith("Error: "):
        return result
    article = readability_extractor.extract_article(result)
    return article.to_markdown()[:4096]
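Likewise for InfoQuest, the `model_extra` lookups above imply a stanza along these lines (values illustrative; units are whatever the InfoQuest API defines):

tools:
  - name: web_search
    search_time_range: 30     # forwarded as time_range
  - name: web_fetch
    fetch_time: 10
    timeout: 30
    navigation_timeout: 15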
@@ -0,0 +1,38 @@
import logging
import os

import requests

logger = logging.getLogger(__name__)


class JinaClient:
    def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
        headers = {
            "Content-Type": "application/json",
            "X-Return-Format": return_format,
            "X-Timeout": str(timeout),
        }
        if os.getenv("JINA_API_KEY"):
            headers["Authorization"] = f"Bearer {os.getenv('JINA_API_KEY')}"
        else:
            logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
        data = {"url": url}
        try:
            response = requests.post("https://r.jina.ai/", headers=headers, json=data)

            if response.status_code != 200:
                error_message = f"Jina API returned status {response.status_code}: {response.text}"
                logger.error(error_message)
                return f"Error: {error_message}"

            if not response.text or not response.text.strip():
                error_message = "Jina API returned empty response"
                logger.error(error_message)
                return f"Error: {error_message}"

            return response.text
        except Exception as e:
            error_message = f"Request to Jina API failed: {str(e)}"
            logger.error(error_message)
            return f"Error: {error_message}"
backend/packages/harness/deerflow/community/jina_ai/tools.py
@@ -0,0 +1,28 @@
from langchain.tools import tool

from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.config import get_app_config
from deerflow.utils.readability import ReadabilityExtractor

readability_extractor = ReadabilityExtractor()


@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
    """Fetch the contents of a web page at a given URL.
    Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
    This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
    Do NOT add www. to URLs that do NOT have them.
    URLs must include the scheme: https://example.com is a valid URL while example.com is an invalid URL.

    Args:
        url: The URL to fetch the contents of.
    """
    jina_client = JinaClient()
    timeout = 10
    config = get_app_config().get_tool_config("web_fetch")
    if config is not None and "timeout" in config.model_extra:
        timeout = config.model_extra.get("timeout")
    html_content = jina_client.crawl(url, return_format="html", timeout=timeout)
    # Surface client errors instead of feeding an "Error: ..." string into the
    # readability extractor (mirrors the guard in the InfoQuest web_fetch tool).
    if html_content.startswith("Error: "):
        return html_content
    article = readability_extractor.extract_article(html_content)
    return article.to_markdown()[:4096]
backend/packages/harness/deerflow/community/tavily/tools.py
@@ -0,0 +1,62 @@
import json

from langchain.tools import tool
from tavily import TavilyClient

from deerflow.config import get_app_config


def _get_tavily_client() -> TavilyClient:
    config = get_app_config().get_tool_config("web_search")
    api_key = None
    if config is not None and "api_key" in config.model_extra:
        api_key = config.model_extra.get("api_key")
    return TavilyClient(api_key=api_key)


@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
    """Search the web.

    Args:
        query: The query to search for.
    """
    config = get_app_config().get_tool_config("web_search")
    max_results = 5
    if config is not None and "max_results" in config.model_extra:
        max_results = config.model_extra.get("max_results")

    client = _get_tavily_client()
    res = client.search(query, max_results=max_results)
    normalized_results = [
        {
            "title": result["title"],
            "url": result["url"],
            "snippet": result["content"],
        }
        for result in res["results"]
    ]
    json_results = json.dumps(normalized_results, indent=2, ensure_ascii=False)
    return json_results


@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
    """Fetch the contents of a web page at a given URL.
    Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
    This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
    Do NOT add www. to URLs that do NOT have them.
    URLs must include the scheme: https://example.com is a valid URL while example.com is an invalid URL.

    Args:
        url: The URL to fetch the contents of.
    """
    client = _get_tavily_client()
    res = client.extract([url])
    if "failed_results" in res and len(res["failed_results"]) > 0:
        return f"Error: {res['failed_results'][0]['error']}"
    elif "results" in res and len(res["results"]) > 0:
        result = res["results"][0]
        return f"# {result['title']}\n\n{result['raw_content'][:4096]}"
    else:
        return "Error: No results found"
backend/packages/harness/deerflow/config/__init__.py
@@ -0,0 +1,19 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths
from .skills_config import SkillsConfig
from .tracing_config import get_tracing_config, is_tracing_enabled

__all__ = [
    "get_app_config",
    "Paths",
    "get_paths",
    "SkillsConfig",
    "ExtensionsConfig",
    "get_extensions_config",
    "MemoryConfig",
    "get_memory_config",
    "get_tracing_config",
    "is_tracing_enabled",
]
backend/packages/harness/deerflow/config/agents_config.py
@@ -0,0 +1,120 @@
"""Configuration and loaders for custom agents."""

import logging
import re
from typing import Any

import yaml
from pydantic import BaseModel

from deerflow.config.paths import get_paths

logger = logging.getLogger(__name__)

SOUL_FILENAME = "SOUL.md"
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")


class AgentConfig(BaseModel):
    """Configuration for a custom agent."""

    name: str
    description: str = ""
    model: str | None = None
    tool_groups: list[str] | None = None


def load_agent_config(name: str | None) -> AgentConfig | None:
    """Load the custom or default agent's config from its directory.

    Args:
        name: The agent name.

    Returns:
        The parsed AgentConfig, or None if `name` is None.

    Raises:
        FileNotFoundError: If the agent directory or config.yaml does not exist.
        ValueError: If config.yaml cannot be parsed.
    """

    if name is None:
        return None

    if not AGENT_NAME_PATTERN.match(name):
        raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
    agent_dir = get_paths().agent_dir(name)
    config_file = agent_dir / "config.yaml"

    if not agent_dir.exists():
        raise FileNotFoundError(f"Agent directory not found: {agent_dir}")

    if not config_file.exists():
        raise FileNotFoundError(f"Agent config not found: {config_file}")

    try:
        with open(config_file, encoding="utf-8") as f:
            data: dict[str, Any] = yaml.safe_load(f) or {}
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse agent config {config_file}: {e}") from e

    # Ensure name is set from directory name if not in file
    if "name" not in data:
        data["name"] = name

    # Strip unknown fields before passing to Pydantic (e.g. legacy prompt_file)
    known_fields = set(AgentConfig.model_fields.keys())
    data = {k: v for k, v in data.items() if k in known_fields}

    return AgentConfig(**data)


def load_agent_soul(agent_name: str | None) -> str | None:
    """Read the SOUL.md file for a custom agent, if it exists.

    SOUL.md defines the agent's personality, values, and behavioral guardrails.
    It is injected into the lead agent's system prompt as additional context.

    Args:
        agent_name: The name of the agent, or None for the default agent.

    Returns:
        The SOUL.md content as a string, or None if the file does not exist.
    """
    agent_dir = get_paths().agent_dir(agent_name) if agent_name else get_paths().base_dir
    soul_path = agent_dir / SOUL_FILENAME
    if not soul_path.exists():
        return None
    content = soul_path.read_text(encoding="utf-8").strip()
    return content or None


def list_custom_agents() -> list[AgentConfig]:
    """Scan the agents directory and return all valid custom agents.

    Returns:
        List of AgentConfig for each valid agent directory found.
    """
    agents_dir = get_paths().agents_dir

    if not agents_dir.exists():
        return []

    agents: list[AgentConfig] = []

    for entry in sorted(agents_dir.iterdir()):
        if not entry.is_dir():
            continue

        config_file = entry / "config.yaml"
        if not config_file.exists():
            logger.debug(f"Skipping {entry.name}: no config.yaml")
            continue

        try:
            agent_cfg = load_agent_config(entry.name)
            agents.append(agent_cfg)
        except Exception as e:
            logger.warning(f"Skipping agent '{entry.name}': {e}")

    return agents
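For orientation, a plausible on-disk layout these loaders expect; the root comes from `get_paths().agents_dir`, and the file contents here are illustrative:

agents/
└── research-writer/          # directory name must match ^[A-Za-z0-9-]+$
    ├── config.yaml           # description, model, tool_groups (name falls back to the dir name)
    └── SOUL.md               # optional personality/guardrail prompt

# config.yaml, illustrative:
#   description: Writes long-form research reports
#   model: my-default-model
#   tool_groups: [web, files]
# Unknown fields (e.g. a legacy prompt_file) are stripped before validation.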
backend/packages/harness/deerflow/config/app_config.py
@@ -0,0 +1,273 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.memory_config import load_memory_config_from_dict
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.skills_config import SkillsConfig
|
||||
from deerflow.config.subagents_config import load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import load_summarization_config_from_dict
|
||||
from deerflow.config.title_config import load_title_config_from_dict
|
||||
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Config for the DeerFlow application"""
|
||||
|
||||
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
|
||||
sandbox: SandboxConfig = Field(description="Sandbox configuration")
|
||||
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
|
||||
tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups")
|
||||
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
|
||||
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
|
||||
model_config = ConfigDict(extra="allow", frozen=False)
|
||||
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
|
||||
|
||||
@classmethod
|
||||
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
||||
"""Resolve the config file path.
|
||||
|
||||
Priority:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, first check the `config.yaml` in the current directory, then fallback to `config.yaml` in the parent directory.
|
||||
"""
|
||||
if config_path:
|
||||
path = Path(config_path)
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
|
||||
return path
|
||||
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
|
||||
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
# Check if the config.yaml is in the current directory
|
||||
path = Path(os.getcwd()) / "config.yaml"
|
||||
if not path.exists():
|
||||
# Check if the config.yaml is in the parent directory of CWD
|
||||
path = Path(os.getcwd()).parent / "config.yaml"
|
||||
if not path.exists():
|
||||
raise FileNotFoundError("`config.yaml` file not found at the current directory nor its parent directory")
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str | None = None) -> Self:
|
||||
"""Load config from YAML file.
|
||||
|
||||
See `resolve_config_path` for more details.
|
||||
|
||||
Args:
|
||||
config_path: Path to the config file.
|
||||
|
||||
Returns:
|
||||
AppConfig: The loaded config.
|
||||
"""
|
||||
resolved_path = cls.resolve_config_path(config_path)
|
||||
with open(resolved_path, encoding="utf-8") as f:
|
||||
config_data = yaml.safe_load(f) or {}
|
||||
|
||||
# Check config version before processing
|
||||
cls._check_config_version(config_data, resolved_path)
|
||||
|
||||
config_data = cls.resolve_env_variables(config_data)
|
||||
|
||||
# Load title config if present
|
||||
if "title" in config_data:
|
||||
load_title_config_from_dict(config_data["title"])
|
||||
|
||||
# Load summarization config if present
|
||||
if "summarization" in config_data:
|
||||
load_summarization_config_from_dict(config_data["summarization"])
|
||||
|
||||
# Load memory config if present
|
||||
if "memory" in config_data:
|
||||
load_memory_config_from_dict(config_data["memory"])
|
||||
|
||||
# Load subagents config if present
|
||||
if "subagents" in config_data:
|
||||
load_subagents_config_from_dict(config_data["subagents"])
|
||||
|
||||
# Load checkpointer config if present
|
||||
if "checkpointer" in config_data:
|
||||
load_checkpointer_config_from_dict(config_data["checkpointer"])
|
||||
|
||||
# Load extensions config separately (it's in a different file)
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
config_data["extensions"] = extensions_config.model_dump()
|
||||
|
||||
result = cls.model_validate(config_data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
|
||||
"""Check if the user's config.yaml is outdated compared to config.example.yaml.
|
||||
|
||||
Emits a warning if the user's config_version is lower than the example's.
|
||||
Missing config_version is treated as version 0 (pre-versioning).
|
||||
"""
|
||||
try:
|
||||
user_version = int(config_data.get("config_version", 0))
|
||||
except (TypeError, ValueError):
|
||||
user_version = 0
|
||||
|
||||
# Find config.example.yaml by searching config.yaml's directory and its parents
|
||||
example_path = None
|
||||
search_dir = config_path.parent
|
||||
for _ in range(5): # search up to 5 levels
|
||||
candidate = search_dir / "config.example.yaml"
|
||||
if candidate.exists():
|
||||
example_path = candidate
|
||||
break
|
||||
            parent = search_dir.parent
            if parent == search_dir:
                break
            search_dir = parent

        if example_path is None:
            return

        try:
            with open(example_path, encoding="utf-8") as f:
                example_data = yaml.safe_load(f)
            raw = example_data.get("config_version", 0) if example_data else 0
            try:
                example_version = int(raw)
            except (TypeError, ValueError):
                example_version = 0
        except Exception:
            return

        if user_version < example_version:
            logger.warning(
                "Your config.yaml (version %d) is outdated — the latest version is %d. "
                "Run `make config-upgrade` to merge new fields into your config.",
                user_version,
                example_version,
            )

    @classmethod
    def resolve_env_variables(cls, config: Any) -> Any:
        """Recursively resolve environment variables in the config.

        Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY

        Args:
            config: The config to resolve environment variables in.

        Returns:
            The config with environment variables resolved.
        """
        if isinstance(config, str):
            if config.startswith("$"):
                env_value = os.getenv(config[1:])
                if env_value is None:
                    raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
                return env_value
            return config
        elif isinstance(config, dict):
            return {k: cls.resolve_env_variables(v) for k, v in config.items()}
        elif isinstance(config, list):
            return [cls.resolve_env_variables(item) for item in config]
        return config

    def get_model_config(self, name: str) -> ModelConfig | None:
        """Get the model config by name.

        Args:
            name: The name of the model to get the config for.

        Returns:
            The model config if found, otherwise None.
        """
        return next((model for model in self.models if model.name == name), None)

    def get_tool_config(self, name: str) -> ToolConfig | None:
        """Get the tool config by name.

        Args:
            name: The name of the tool to get the config for.

        Returns:
            The tool config if found, otherwise None.
        """
        return next((tool for tool in self.tools if tool.name == name), None)

    def get_tool_group_config(self, name: str) -> ToolGroupConfig | None:
        """Get the tool group config by name.

        Args:
            name: The name of the tool group to get the config for.

        Returns:
            The tool group config if found, otherwise None.
        """
        return next((group for group in self.tool_groups if group.name == name), None)


_app_config: AppConfig | None = None


def get_app_config() -> AppConfig:
    """Get the DeerFlow config instance.

    Returns a cached singleton instance. Use `reload_app_config()` to reload
    from file, or `reset_app_config()` to clear the cache.
    """
    global _app_config
    if _app_config is None:
        _app_config = AppConfig.from_file()
    return _app_config


def reload_app_config(config_path: str | None = None) -> AppConfig:
    """Reload the config from file and update the cached instance.

    This is useful when the config file has been modified and you want
    to pick up the changes without restarting the application.

    Args:
        config_path: Optional path to config file. If not provided,
            uses the default resolution strategy.

    Returns:
        The newly loaded AppConfig instance.
    """
    global _app_config
    _app_config = AppConfig.from_file(config_path)
    return _app_config


def reset_app_config() -> None:
    """Reset the cached config instance.

    This clears the singleton cache, causing the next call to
    `get_app_config()` to reload from file. Useful for testing
    or when switching between different configurations.
    """
    global _app_config
    _app_config = None


def set_app_config(config: AppConfig) -> None:
    """Set a custom config instance.

    This allows injecting a custom or mock config for testing purposes.

    Args:
        config: The AppConfig instance to use.
    """
    global _app_config
    _app_config = config
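
A minimal usage sketch of the singleton helpers above, assuming the module is importable as deerflow.config.app_config (path inferred from the layout) and that a model named "default" is declared in config.yaml:

from deerflow.config.app_config import get_app_config, reload_app_config, reset_app_config

config = get_app_config()                   # first call loads config.yaml and caches it
model = config.get_model_config("default")  # None if no model with that name is declared

reload_app_config()  # re-read config.yaml in place (e.g. after `make config-upgrade`)
reset_app_config()   # drop the cache; the next get_app_config() loads from file again
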
@@ -0,0 +1,46 @@
"""Configuration for LangGraph checkpointer."""

from typing import Literal

from pydantic import BaseModel, Field

CheckpointerType = Literal["memory", "sqlite", "postgres"]


class CheckpointerConfig(BaseModel):
    """Configuration for LangGraph state persistence checkpointer."""

    type: CheckpointerType = Field(
        description="Checkpointer backend type. "
        "'memory' is in-process only (lost on restart). "
        "'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
        "'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
    )
    connection_string: str | None = Field(
        default=None,
        description="Connection string for sqlite (file path) or postgres (DSN). "
        "Required for sqlite and postgres types. "
        "For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
        "For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
    )


# Global configuration instance — None means no checkpointer is configured.
_checkpointer_config: CheckpointerConfig | None = None


def get_checkpointer_config() -> CheckpointerConfig | None:
    """Get the current checkpointer configuration, or None if not configured."""
    return _checkpointer_config


def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
    """Set the checkpointer configuration."""
    global _checkpointer_config
    _checkpointer_config = config


def load_checkpointer_config_from_dict(config_dict: dict) -> None:
    """Load checkpointer configuration from a dictionary."""
    global _checkpointer_config
    _checkpointer_config = CheckpointerConfig(**config_dict)
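
A minimal sketch of driving this module from a config dict (the sqlite path is illustrative):

load_checkpointer_config_from_dict(
    {
        "type": "sqlite",
        "connection_string": ".deer-flow/checkpoints.db",
    }
)
assert get_checkpointer_config() is not None
assert get_checkpointer_config().type == "sqlite"
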
backend/packages/harness/deerflow/config/extensions_config.py (new file, 258 lines)
@@ -0,0 +1,258 @@
"""Unified extensions configuration for MCP servers and skills."""

import json
import os
from pathlib import Path
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field


class McpOAuthConfig(BaseModel):
    """OAuth configuration for an MCP server (HTTP/SSE transports)."""

    enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
    token_url: str = Field(description="OAuth token endpoint URL")
    grant_type: Literal["client_credentials", "refresh_token"] = Field(
        default="client_credentials",
        description="OAuth grant type",
    )
    client_id: str | None = Field(default=None, description="OAuth client ID")
    client_secret: str | None = Field(default=None, description="OAuth client secret")
    refresh_token: str | None = Field(default=None, description="OAuth refresh token (for refresh_token grant)")
    scope: str | None = Field(default=None, description="OAuth scope")
    audience: str | None = Field(default=None, description="OAuth audience (provider-specific)")
    token_field: str = Field(default="access_token", description="Field name containing access token in token response")
    token_type_field: str = Field(default="token_type", description="Field name containing token type in token response")
    expires_in_field: str = Field(default="expires_in", description="Field name containing expiry (seconds) in token response")
    default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
    refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
    extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
    model_config = ConfigDict(extra="allow")


class McpServerConfig(BaseModel):
    """Configuration for a single MCP server."""

    enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
    type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
    command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
    args: list[str] = Field(default_factory=list, description="Arguments to pass to the command (for stdio type)")
    env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
    url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
    headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
    oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
    description: str = Field(default="", description="Human-readable description of what this MCP server provides")
    model_config = ConfigDict(extra="allow")


class SkillStateConfig(BaseModel):
    """Configuration for a single skill's state."""

    enabled: bool = Field(default=True, description="Whether this skill is enabled")


class ExtensionsConfig(BaseModel):
    """Unified configuration for MCP servers and skills."""

    mcp_servers: dict[str, McpServerConfig] = Field(
        default_factory=dict,
        description="Map of MCP server name to configuration",
        alias="mcpServers",
    )
    skills: dict[str, SkillStateConfig] = Field(
        default_factory=dict,
        description="Map of skill name to state configuration",
    )
    model_config = ConfigDict(extra="allow", populate_by_name=True)

    @classmethod
    def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
        """Resolve the extensions config file path.

        Priority:
        1. If the `config_path` argument is provided, use it.
        2. If the `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable is set, use it.
        3. Otherwise, check for `extensions_config.json` in the current directory, then in the parent directory.
        4. For backward compatibility, also check for `mcp_config.json` if `extensions_config.json` is not found.
        5. If not found, return None (extensions are optional).

        Args:
            config_path: Optional path to extensions config file.

        Returns:
            Path to the extensions config file if found, otherwise None.
        """
        if config_path:
            path = Path(config_path)
            if not path.exists():
                raise FileNotFoundError(f"Extensions config file specified by param `config_path` not found at {path}")
            return path
        elif os.getenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH"):
            path = Path(os.getenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH"))
            if not path.exists():
                raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
            return path
        else:
            # Check if extensions_config.json is in the current directory
            path = Path(os.getcwd()) / "extensions_config.json"
            if path.exists():
                return path

            # Check if extensions_config.json is in the parent directory of CWD
            path = Path(os.getcwd()).parent / "extensions_config.json"
            if path.exists():
                return path

            # Backward compatibility: check for mcp_config.json
            path = Path(os.getcwd()) / "mcp_config.json"
            if path.exists():
                return path

            path = Path(os.getcwd()).parent / "mcp_config.json"
            if path.exists():
                return path

            # Extensions are optional, so return None if not found
            return None

    @classmethod
    def from_file(cls, config_path: str | None = None) -> "ExtensionsConfig":
        """Load extensions config from a JSON file.

        See `resolve_config_path` for more details.

        Args:
            config_path: Path to the extensions config file.

        Returns:
            ExtensionsConfig: The loaded config, or an empty config if the file is not found.
        """
        resolved_path = cls.resolve_config_path(config_path)
        if resolved_path is None:
            # Return empty config if extensions config file is not found
            return cls(mcp_servers={}, skills={})

        try:
            with open(resolved_path, encoding="utf-8") as f:
                config_data = json.load(f)
            cls.resolve_env_variables(config_data)
            return cls.model_validate(config_data)
        except json.JSONDecodeError as e:
            raise ValueError(f"Extensions config file at {resolved_path} is not valid JSON: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to load extensions config from {resolved_path}: {e}") from e

    @classmethod
    def resolve_env_variables(cls, config: dict[str, Any]) -> dict[str, Any]:
        """Recursively resolve environment variables in the config.

        Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY

        Args:
            config: The config to resolve environment variables in.

        Returns:
            The config with environment variables resolved.
        """
        for key, value in config.items():
            if isinstance(value, str):
                if value.startswith("$"):
                    env_value = os.getenv(value[1:])
                    if env_value is None:
                        # Unresolved placeholder — store empty string so downstream
                        # consumers (e.g. MCP servers) don't receive the literal "$VAR"
                        # token as an actual environment value.
                        config[key] = ""
                    else:
                        config[key] = env_value
                else:
                    config[key] = value
            elif isinstance(value, dict):
                config[key] = cls.resolve_env_variables(value)
            elif isinstance(value, list):
                config[key] = [cls.resolve_env_variables(item) if isinstance(item, dict) else item for item in value]
        return config

    def get_enabled_mcp_servers(self) -> dict[str, McpServerConfig]:
        """Get only the enabled MCP servers.

        Returns:
            Dictionary of enabled MCP servers.
        """
        return {name: config for name, config in self.mcp_servers.items() if config.enabled}

    def is_skill_enabled(self, skill_name: str, skill_category: str) -> bool:
        """Check if a skill is enabled.

        Args:
            skill_name: Name of the skill
            skill_category: Category of the skill

        Returns:
            True if enabled, False otherwise
        """
        skill_config = self.skills.get(skill_name)
        if skill_config is None:
            # Default to enabled for public & custom skills
            return skill_category in ("public", "custom")
        return skill_config.enabled


_extensions_config: ExtensionsConfig | None = None


def get_extensions_config() -> ExtensionsConfig:
    """Get the extensions config instance.

    Returns a cached singleton instance. Use `reload_extensions_config()` to reload
    from file, or `reset_extensions_config()` to clear the cache.

    Returns:
        The cached ExtensionsConfig instance.
    """
    global _extensions_config
    if _extensions_config is None:
        _extensions_config = ExtensionsConfig.from_file()
    return _extensions_config


def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
    """Reload the extensions config from file and update the cached instance.

    This is useful when the config file has been modified and you want
    to pick up the changes without restarting the application.

    Args:
        config_path: Optional path to extensions config file. If not provided,
            uses the default resolution strategy.

    Returns:
        The newly loaded ExtensionsConfig instance.
    """
    global _extensions_config
    _extensions_config = ExtensionsConfig.from_file(config_path)
    return _extensions_config


def reset_extensions_config() -> None:
    """Reset the cached extensions config instance.

    This clears the singleton cache, causing the next call to
    `get_extensions_config()` to reload from file. Useful for testing
    or when switching between different configurations.
    """
    global _extensions_config
    _extensions_config = None


def set_extensions_config(config: ExtensionsConfig) -> None:
    """Set a custom extensions config instance.

    This allows injecting a custom or mock config for testing purposes.

    Args:
        config: The ExtensionsConfig instance to use.
    """
    global _extensions_config
    _extensions_config = config
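
A minimal sketch of the JSON shape this model accepts, expressed as the dict it parses; the server name, command, skill name, and env var are illustrative, not shipped defaults:

example = {
    "mcpServers": {
        "filesystem": {
            "type": "stdio",
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
            "env": {"API_KEY": "$MY_API_KEY"},  # "$VAR" values resolve from the host env
        }
    },
    "skills": {"web-clipper": {"enabled": False}},
}
ExtensionsConfig.resolve_env_variables(example)   # mutates the dict in place
config = ExtensionsConfig.model_validate(example)
assert "filesystem" in config.get_enabled_mcp_servers()
assert not config.is_skill_enabled("web-clipper", "public")
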
backend/packages/harness/deerflow/config/memory_config.py (new file, 78 lines)
@@ -0,0 +1,78 @@
"""Configuration for memory mechanism."""

from pydantic import BaseModel, Field


class MemoryConfig(BaseModel):
    """Configuration for the global memory mechanism."""

    enabled: bool = Field(
        default=True,
        description="Whether to enable the memory mechanism",
    )
    storage_path: str = Field(
        default="",
        description=(
            "Path to store memory data. "
            "If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
            "Absolute paths are used as-is. "
            "Relative paths are resolved against `Paths.base_dir` "
            "(not the backend working directory). "
            "Note: if you previously set this to `.deer-flow/memory.json`, "
            "the file will now be resolved as `{base_dir}/.deer-flow/memory.json`; "
            "migrate existing data or use an absolute path to preserve the old location."
        ),
    )
    debounce_seconds: int = Field(
        default=30,
        ge=1,
        le=300,
        description="Seconds to wait before processing queued updates (debounce)",
    )
    model_name: str | None = Field(
        default=None,
        description="Model name to use for memory updates (None = use default model)",
    )
    max_facts: int = Field(
        default=100,
        ge=10,
        le=500,
        description="Maximum number of facts to store",
    )
    fact_confidence_threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum confidence threshold for storing facts",
    )
    injection_enabled: bool = Field(
        default=True,
        description="Whether to inject memory into the system prompt",
    )
    max_injection_tokens: int = Field(
        default=2000,
        ge=100,
        le=8000,
        description="Maximum tokens to use for memory injection",
    )


# Global configuration instance
_memory_config: MemoryConfig = MemoryConfig()


def get_memory_config() -> MemoryConfig:
    """Get the current memory configuration."""
    return _memory_config


def set_memory_config(config: MemoryConfig) -> None:
    """Set the memory configuration."""
    global _memory_config
    _memory_config = config


def load_memory_config_from_dict(config_dict: dict) -> None:
    """Load memory configuration from a dictionary."""
    global _memory_config
    _memory_config = MemoryConfig(**config_dict)
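
A small sketch of the dict form accepted at load time (values illustrative); note that a relative storage_path resolves against Paths.base_dir, as the field description warns:

load_memory_config_from_dict({"enabled": True, "storage_path": "memory/alt.json", "max_facts": 50})
cfg = get_memory_config()
assert cfg.max_facts == 50  # storage_path is later resolved as {base_dir}/memory/alt.json
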
backend/packages/harness/deerflow/config/model_config.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from pydantic import BaseModel, ConfigDict, Field


class ModelConfig(BaseModel):
    """Config section for a model"""

    name: str = Field(..., description="Unique name for the model")
    display_name: str | None = Field(default=None, description="Display name for the model")
    description: str | None = Field(default=None, description="Description for the model")
    use: str = Field(
        ...,
        description="Class path of the model provider (e.g. langchain_openai.ChatOpenAI)",
    )
    model: str = Field(..., description="Model name")
    model_config = ConfigDict(extra="allow")
    supports_thinking: bool = Field(default=False, description="Whether the model supports thinking")
    supports_reasoning_effort: bool = Field(default=False, description="Whether the model supports reasoning effort")
    when_thinking_enabled: dict | None = Field(
        default=None,
        description="Extra settings to be passed to the model when thinking is enabled",
    )
    supports_vision: bool = Field(default=False, description="Whether the model supports vision/image inputs")
    thinking: dict | None = Field(
        default=None,
        description=(
            "Thinking settings for the model. If provided, these settings will be passed to the model when thinking is enabled. "
            "This is a shortcut for `when_thinking_enabled` and will be merged with `when_thinking_enabled` if both are provided."
        ),
    )
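
A hypothetical models entry from config.yaml, as ModelConfig validates it; extra provider kwargs such as api_key pass through because of extra="allow":

cfg = ModelConfig.model_validate(
    {
        "name": "default",
        "use": "langchain_openai.ChatOpenAI",
        "model": "gpt-4o-mini",
        "supports_vision": True,
        "api_key": "$OPENAI_API_KEY",  # kept as an extra field for the provider
    }
)
assert cfg.display_name is None and cfg.supports_thinking is False
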
backend/packages/harness/deerflow/config/paths.py (new file, 216 lines)
@@ -0,0 +1,216 @@
import os
import re
from pathlib import Path

# Virtual path prefix seen by agents inside the sandbox
VIRTUAL_PATH_PREFIX = "/mnt/user-data"

_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")


class Paths:
    """
    Centralized path configuration for DeerFlow application data.

    Directory layout (host side):
        {base_dir}/
        ├── memory.json
        ├── USER.md                 <-- global user profile (injected into all agents)
        ├── agents/
        │   └── {agent_name}/
        │       ├── config.yaml
        │       ├── SOUL.md         <-- agent personality/identity (injected alongside lead prompt)
        │       └── memory.json
        └── threads/
            └── {thread_id}/
                └── user-data/      <-- mounted as /mnt/user-data/ inside sandbox
                    ├── workspace/  <-- /mnt/user-data/workspace/
                    ├── uploads/    <-- /mnt/user-data/uploads/
                    └── outputs/    <-- /mnt/user-data/outputs/

    BaseDir resolution (in priority order):
        1. Constructor argument `base_dir`
        2. DEER_FLOW_HOME environment variable
        3. Local dev fallback: cwd/.deer-flow (when cwd is the backend/ dir)
        4. Default: $HOME/.deer-flow
    """

    def __init__(self, base_dir: str | Path | None = None) -> None:
        self._base_dir = Path(base_dir).resolve() if base_dir is not None else None

    @property
    def host_base_dir(self) -> Path:
        """Host-visible base dir for Docker volume mount sources.

        When running inside Docker with a mounted Docker socket (DooD), the Docker
        daemon runs on the host and resolves mount paths against the host filesystem.
        Set DEER_FLOW_HOST_BASE_DIR to the host-side path that corresponds to this
        container's base_dir so that sandbox container volume mounts work correctly.

        Falls back to base_dir when the env var is not set (native/local execution).
        """
        if env := os.getenv("DEER_FLOW_HOST_BASE_DIR"):
            return Path(env)
        return self.base_dir

    @property
    def base_dir(self) -> Path:
        """Root directory for all application data."""
        if self._base_dir is not None:
            return self._base_dir

        if env_home := os.getenv("DEER_FLOW_HOME"):
            return Path(env_home).resolve()

        cwd = Path.cwd()
        if cwd.name == "backend" or (cwd / "pyproject.toml").exists():
            return cwd / ".deer-flow"

        return Path.home() / ".deer-flow"

    @property
    def memory_file(self) -> Path:
        """Path to the persisted memory file: `{base_dir}/memory.json`."""
        return self.base_dir / "memory.json"

    @property
    def user_md_file(self) -> Path:
        """Path to the global user profile file: `{base_dir}/USER.md`."""
        return self.base_dir / "USER.md"

    @property
    def agents_dir(self) -> Path:
        """Root directory for all custom agents: `{base_dir}/agents/`."""
        return self.base_dir / "agents"

    def agent_dir(self, name: str) -> Path:
        """Directory for a specific agent: `{base_dir}/agents/{name}/`."""
        return self.agents_dir / name.lower()

    def agent_memory_file(self, name: str) -> Path:
        """Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
        return self.agent_dir(name) / "memory.json"

    def thread_dir(self, thread_id: str) -> Path:
        """
        Host path for a thread's data: `{base_dir}/threads/{thread_id}/`

        This directory contains a `user-data/` subdirectory that is mounted
        as `/mnt/user-data/` inside the sandbox.

        Raises:
            ValueError: If `thread_id` contains unsafe characters (path separators
                or `..`) that could cause directory traversal.
        """
        if not _SAFE_THREAD_ID_RE.match(thread_id):
            raise ValueError(f"Invalid thread_id {thread_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
        return self.base_dir / "threads" / thread_id

    def sandbox_work_dir(self, thread_id: str) -> Path:
        """
        Host path for the agent's workspace directory.
        Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
        Sandbox: `/mnt/user-data/workspace/`
        """
        return self.thread_dir(thread_id) / "user-data" / "workspace"

    def sandbox_uploads_dir(self, thread_id: str) -> Path:
        """
        Host path for user-uploaded files.
        Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
        Sandbox: `/mnt/user-data/uploads/`
        """
        return self.thread_dir(thread_id) / "user-data" / "uploads"

    def sandbox_outputs_dir(self, thread_id: str) -> Path:
        """
        Host path for agent-generated artifacts.
        Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
        Sandbox: `/mnt/user-data/outputs/`
        """
        return self.thread_dir(thread_id) / "user-data" / "outputs"

    def sandbox_user_data_dir(self, thread_id: str) -> Path:
        """
        Host path for the user-data root.
        Host: `{base_dir}/threads/{thread_id}/user-data/`
        Sandbox: `/mnt/user-data/`
        """
        return self.thread_dir(thread_id) / "user-data"

    def ensure_thread_dirs(self, thread_id: str) -> None:
        """Create all standard sandbox directories for a thread.

        Directories are created with mode 0o777 so that sandbox containers
        (which may run as a different UID than the host backend process) can
        write to the volume-mounted paths without "Permission denied" errors.
        The explicit chmod() call is necessary because Path.mkdir(mode=...) is
        subject to the process umask and may not yield the intended permissions.
        """
        for d in [
            self.sandbox_work_dir(thread_id),
            self.sandbox_uploads_dir(thread_id),
            self.sandbox_outputs_dir(thread_id),
        ]:
            d.mkdir(parents=True, exist_ok=True)
            d.chmod(0o777)

    def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
        """Resolve a sandbox virtual path to the actual host filesystem path.

        Args:
            thread_id: The thread ID.
            virtual_path: Virtual path as seen inside the sandbox, e.g.
                ``/mnt/user-data/outputs/report.pdf``.
                Leading slashes are stripped before matching.

        Returns:
            The resolved absolute host filesystem path.

        Raises:
            ValueError: If the path does not start with the expected virtual
                prefix or a path-traversal attempt is detected.
        """
        stripped = virtual_path.lstrip("/")
        prefix = VIRTUAL_PATH_PREFIX.lstrip("/")

        # Require an exact segment-boundary match to avoid prefix confusion
        # (e.g. reject paths like "mnt/user-dataX/...").
        if stripped != prefix and not stripped.startswith(prefix + "/"):
            raise ValueError(f"Path must start with /{prefix}")

        relative = stripped[len(prefix) :].lstrip("/")
        base = self.sandbox_user_data_dir(thread_id).resolve()
        actual = (base / relative).resolve()

        try:
            actual.relative_to(base)
        except ValueError:
            raise ValueError("Access denied: path traversal detected") from None

        return actual


# ── Singleton ────────────────────────────────────────────────────────────

_paths: Paths | None = None


def get_paths() -> Paths:
    """Return the global Paths singleton (lazy-initialized)."""
    global _paths
    if _paths is None:
        _paths = Paths()
    return _paths


def resolve_path(path: str) -> Path:
    """Resolve *path* to an absolute ``Path``.

    Relative paths are resolved relative to the application base directory.
    Absolute paths are returned as-is (after normalisation).
    """
    p = Path(path)
    if not p.is_absolute():
        p = get_paths().base_dir / path
    return p.resolve()
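
A sketch of the virtual-path round trip (thread id and file name are invented):

paths = get_paths()
paths.ensure_thread_dirs("thread-42")
host = paths.resolve_virtual_path("thread-42", "/mnt/user-data/outputs/report.pdf")
# host == {base_dir}/threads/thread-42/user-data/outputs/report.pdf

# Both of the following raise ValueError: the first fails the segment-boundary
# prefix check, the second is caught by the relative_to() traversal guard.
# paths.resolve_virtual_path("thread-42", "/mnt/user-dataX/file")
# paths.resolve_virtual_path("thread-42", "/mnt/user-data/../../etc/passwd")
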
backend/packages/harness/deerflow/config/sandbox_config.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from pydantic import BaseModel, ConfigDict, Field


class VolumeMountConfig(BaseModel):
    """Configuration for a volume mount."""

    host_path: str = Field(..., description="Path on the host machine")
    container_path: str = Field(..., description="Path inside the container")
    read_only: bool = Field(default=False, description="Whether the mount is read-only")


class SandboxConfig(BaseModel):
    """Config section for a sandbox.

    Common options:
        use: Class path of the sandbox provider (required)

    AioSandboxProvider specific options:
        image: Docker image to use (default: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest)
        port: Base port for sandbox containers (default: 8080)
        replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
        container_prefix: Prefix for container names (default: deer-flow-sandbox)
        idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
        mounts: List of volume mounts to share directories with the container
        environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
    """

    use: str = Field(
        ...,
        description="Class path of the sandbox provider (e.g. deerflow.sandbox.local:LocalSandboxProvider)",
    )
    image: str | None = Field(
        default=None,
        description="Docker image to use for the sandbox container",
    )
    port: int | None = Field(
        default=None,
        description="Base port for sandbox containers",
    )
    replicas: int | None = Field(
        default=None,
        description="Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.",
    )
    container_prefix: str | None = Field(
        default=None,
        description="Prefix for container names",
    )
    idle_timeout: int | None = Field(
        default=None,
        description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
    )
    mounts: list[VolumeMountConfig] = Field(
        default_factory=list,
        description="List of volume mounts to share directories between host and container",
    )
    environment: dict[str, str] = Field(
        default_factory=dict,
        description="Environment variables to inject into the sandbox container. Values starting with $ will be resolved from host environment variables.",
    )

    model_config = ConfigDict(extra="allow")
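
A hypothetical sandbox section from config.yaml, using the LocalSandboxProvider class path named in the field description; mount paths and the env var are illustrative:

sandbox = SandboxConfig.model_validate(
    {
        "use": "deerflow.sandbox.local:LocalSandboxProvider",
        "replicas": 2,
        "idle_timeout": 300,
        "mounts": [{"host_path": "/srv/data", "container_path": "/mnt/data", "read_only": True}],
        "environment": {"HF_TOKEN": "$HF_TOKEN"},  # resolved from the host env
    }
)
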
backend/packages/harness/deerflow/config/skills_config.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from pathlib import Path

from pydantic import BaseModel, Field


class SkillsConfig(BaseModel):
    """Configuration for skills system"""

    path: str | None = Field(
        default=None,
        description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
    )
    container_path: str = Field(
        default="/mnt/skills",
        description="Path where skills are mounted in the sandbox container",
    )

    def get_skills_path(self) -> Path:
        """
        Get the resolved skills directory path.

        Returns:
            Path to the skills directory
        """
        if self.path:
            # Use configured path (can be absolute or relative)
            path = Path(self.path)
            if not path.is_absolute():
                # If relative, resolve from current working directory
                path = Path.cwd() / path
            return path.resolve()
        else:
            # Default: ../skills relative to backend directory
            from deerflow.skills.loader import get_skills_root_path

            return get_skills_root_path()

    def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
        """
        Get the full container path for a specific skill.

        Args:
            skill_name: Name of the skill (directory name)
            category: Category of the skill (public or custom)

        Returns:
            Full path to the skill in the container
        """
        return f"{self.container_path}/{category}/{skill_name}"
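
For example, with the defaults above (the skill name is invented):

cfg = SkillsConfig()
assert cfg.get_skill_container_path("pdf-report", "custom") == "/mnt/skills/custom/pdf-report"
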
backend/packages/harness/deerflow/config/subagents_config.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Configuration for the subagent system loaded from config.yaml."""

import logging

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


class SubagentOverrideConfig(BaseModel):
    """Per-agent configuration overrides."""

    timeout_seconds: int | None = Field(
        default=None,
        ge=1,
        description="Timeout in seconds for this subagent (None = use global default)",
    )


class SubagentsAppConfig(BaseModel):
    """Configuration for the subagent system."""

    timeout_seconds: int = Field(
        default=900,
        ge=1,
        description="Default timeout in seconds for all subagents (default: 900 = 15 minutes)",
    )
    agents: dict[str, SubagentOverrideConfig] = Field(
        default_factory=dict,
        description="Per-agent configuration overrides keyed by agent name",
    )

    def get_timeout_for(self, agent_name: str) -> int:
        """Get the effective timeout for a specific agent.

        Args:
            agent_name: The name of the subagent.

        Returns:
            The timeout in seconds, using per-agent override if set, otherwise global default.
        """
        override = self.agents.get(agent_name)
        if override is not None and override.timeout_seconds is not None:
            return override.timeout_seconds
        return self.timeout_seconds


_subagents_config: SubagentsAppConfig = SubagentsAppConfig()


def get_subagents_app_config() -> SubagentsAppConfig:
    """Get the current subagents configuration."""
    return _subagents_config


def load_subagents_config_from_dict(config_dict: dict) -> None:
    """Load subagents configuration from a dictionary."""
    global _subagents_config
    _subagents_config = SubagentsAppConfig(**config_dict)

    overrides_summary = {name: f"{override.timeout_seconds}s" for name, override in _subagents_config.agents.items() if override.timeout_seconds is not None}
    if overrides_summary:
        logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, per-agent overrides={overrides_summary}")
    else:
        logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, no per-agent overrides")
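
A sketch of the override resolution (agent names invented):

cfg = SubagentsAppConfig(
    timeout_seconds=600,
    agents={"researcher": SubagentOverrideConfig(timeout_seconds=1800)},
)
assert cfg.get_timeout_for("researcher") == 1800  # per-agent override wins
assert cfg.get_timeout_for("coder") == 600        # falls back to the global default
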
@@ -0,0 +1,74 @@
"""Configuration for conversation summarization."""

from typing import Literal

from pydantic import BaseModel, Field

ContextSizeType = Literal["fraction", "tokens", "messages"]


class ContextSize(BaseModel):
    """Context size specification for trigger or keep parameters."""

    type: ContextSizeType = Field(description="Type of context size specification")
    value: int | float = Field(description="Value for the context size specification")

    def to_tuple(self) -> tuple[ContextSizeType, int | float]:
        """Convert to tuple format expected by SummarizationMiddleware."""
        return (self.type, self.value)


class SummarizationConfig(BaseModel):
    """Configuration for automatic conversation summarization."""

    enabled: bool = Field(
        default=False,
        description="Whether to enable automatic conversation summarization",
    )
    model_name: str | None = Field(
        default=None,
        description="Model name to use for summarization (None = use a lightweight model)",
    )
    trigger: ContextSize | list[ContextSize] | None = Field(
        default=None,
        description="One or more thresholds that trigger summarization. When any threshold is met, summarization runs. "
        "Examples: {'type': 'messages', 'value': 50} triggers at 50 messages, "
        "{'type': 'tokens', 'value': 4000} triggers at 4000 tokens, "
        "{'type': 'fraction', 'value': 0.8} triggers at 80% of model's max input tokens",
    )
    keep: ContextSize = Field(
        default_factory=lambda: ContextSize(type="messages", value=20),
        description="Context retention policy after summarization. Specifies how much history to preserve. "
        "Examples: {'type': 'messages', 'value': 20} keeps 20 messages, "
        "{'type': 'tokens', 'value': 3000} keeps 3000 tokens, "
        "{'type': 'fraction', 'value': 0.3} keeps 30% of model's max input tokens",
    )
    trim_tokens_to_summarize: int | None = Field(
        default=4000,
        description="Maximum tokens to keep when preparing messages for summarization. Pass null to skip trimming.",
    )
    summary_prompt: str | None = Field(
        default=None,
        description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
    )


# Global configuration instance
_summarization_config: SummarizationConfig = SummarizationConfig()


def get_summarization_config() -> SummarizationConfig:
    """Get the current summarization configuration."""
    return _summarization_config


def set_summarization_config(config: SummarizationConfig) -> None:
    """Set the summarization configuration."""
    global _summarization_config
    _summarization_config = config


def load_summarization_config_from_dict(config_dict: dict) -> None:
    """Load summarization configuration from a dictionary."""
    global _summarization_config
    _summarization_config = SummarizationConfig(**config_dict)
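
A sketch of a multi-threshold trigger, in the dict form accepted at load time:

load_summarization_config_from_dict(
    {
        "enabled": True,
        "trigger": [
            {"type": "messages", "value": 50},
            {"type": "fraction", "value": 0.8},
        ],
        "keep": {"type": "messages", "value": 20},
    }
)
assert get_summarization_config().keep.to_tuple() == ("messages", 20)
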
backend/packages/harness/deerflow/config/title_config.py (new file, 53 lines)
@@ -0,0 +1,53 @@
"""Configuration for automatic thread title generation."""

from pydantic import BaseModel, Field


class TitleConfig(BaseModel):
    """Configuration for automatic thread title generation."""

    enabled: bool = Field(
        default=True,
        description="Whether to enable automatic title generation",
    )
    max_words: int = Field(
        default=6,
        ge=1,
        le=20,
        description="Maximum number of words in the generated title",
    )
    max_chars: int = Field(
        default=60,
        ge=10,
        le=200,
        description="Maximum number of characters in the generated title",
    )
    model_name: str | None = Field(
        default=None,
        description="Model name to use for title generation (None = use default model)",
    )
    prompt_template: str = Field(
        default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
        description="Prompt template for title generation",
    )


# Global configuration instance
_title_config: TitleConfig = TitleConfig()


def get_title_config() -> TitleConfig:
    """Get the current title configuration."""
    return _title_config


def set_title_config(config: TitleConfig) -> None:
    """Set the title configuration."""
    global _title_config
    _title_config = config


def load_title_config_from_dict(config_dict: dict) -> None:
    """Load title configuration from a dictionary."""
    global _title_config
    _title_config = TitleConfig(**config_dict)
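
A sketch of how the template is meant to be filled (the messages are invented):

tc = get_title_config()
prompt = tc.prompt_template.format(
    max_words=tc.max_words,
    user_msg="How do I split a backend into two packages?",
    assistant_msg="You can use a uv workspace...",
)
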
backend/packages/harness/deerflow/config/tool_config.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from pydantic import BaseModel, ConfigDict, Field


class ToolGroupConfig(BaseModel):
    """Config section for a tool group"""

    name: str = Field(..., description="Unique name for the tool group")
    model_config = ConfigDict(extra="allow")


class ToolConfig(BaseModel):
    """Config section for a tool"""

    name: str = Field(..., description="Unique name for the tool")
    group: str = Field(..., description="Group name for the tool")
    use: str = Field(
        ...,
        description="Variable name of the tool provider (e.g. deerflow.sandbox.tools:bash_tool)",
    )
    model_config = ConfigDict(extra="allow")
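
A hypothetical tools entry, using the provider path named in the field description:

tool = ToolConfig.model_validate(
    {"name": "bash", "group": "sandbox", "use": "deerflow.sandbox.tools:bash_tool"}
)
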
backend/packages/harness/deerflow/config/tracing_config.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import logging
import os
import threading

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)
_config_lock = threading.Lock()


class TracingConfig(BaseModel):
    """Configuration for LangSmith tracing."""

    enabled: bool = Field(...)
    api_key: str | None = Field(...)
    project: str = Field(...)
    endpoint: str = Field(...)

    @property
    def is_configured(self) -> bool:
        """Check if tracing is fully configured (enabled and has API key)."""
        return self.enabled and bool(self.api_key)


_tracing_config: TracingConfig | None = None


_TRUTHY_VALUES = {"1", "true", "yes", "on"}


def _env_flag_preferred(*names: str) -> bool:
    """Return the boolean value of the first env var that is present and non-empty.

    Accepted truthy values (case-insensitive): ``1``, ``true``, ``yes``, ``on``.
    Any other non-empty value is treated as falsy. If none of the named
    variables is set, returns ``False``.
    """
    for name in names:
        value = os.environ.get(name)
        if value is not None and value.strip():
            return value.strip().lower() in _TRUTHY_VALUES
    return False


def _first_env_value(*names: str) -> str | None:
    """Return the first non-empty environment value from candidate names."""
    for name in names:
        value = os.environ.get(name)
        if value and value.strip():
            return value.strip()
    return None


def get_tracing_config() -> TracingConfig:
    """Get the current tracing configuration from environment variables.

    ``LANGSMITH_*`` variables take precedence over their legacy ``LANGCHAIN_*``
    counterparts. For boolean flags (``enabled``), the *first* variable that is
    present and non-empty in the priority list is the sole authority – its value
    is parsed and returned without consulting the remaining candidates. Accepted
    truthy values are ``1``, ``true``, ``yes``, and ``on`` (case-insensitive);
    any other non-empty value is treated as falsy.

    Priority order:
        enabled  : LANGSMITH_TRACING > LANGCHAIN_TRACING_V2 > LANGCHAIN_TRACING
        api_key  : LANGSMITH_API_KEY > LANGCHAIN_API_KEY
        project  : LANGSMITH_PROJECT > LANGCHAIN_PROJECT (default: "deer-flow")
        endpoint : LANGSMITH_ENDPOINT > LANGCHAIN_ENDPOINT (default: https://api.smith.langchain.com)

    Returns:
        TracingConfig with current settings.
    """
    global _tracing_config
    if _tracing_config is not None:
        return _tracing_config
    with _config_lock:
        if _tracing_config is not None:  # Double-check after acquiring lock
            return _tracing_config
        _tracing_config = TracingConfig(
            # Keep compatibility with both legacy LANGCHAIN_* and newer LANGSMITH_* variables.
            enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
            api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
            project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
            endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
        )
        return _tracing_config


def is_tracing_enabled() -> bool:
    """Check if LangSmith tracing is enabled and configured.

    Returns:
        True if tracing is enabled and has an API key.
    """
    return get_tracing_config().is_configured
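
A sketch of the first-match-wins rule for the enabled flag; this must run before the first get_tracing_config() call, since the result is cached afterwards:

import os

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGSMITH_TRACING"] = "0"  # higher priority and non-empty, so it decides alone
assert get_tracing_config().enabled is False
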
backend/packages/harness/deerflow/mcp/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""MCP (Model Context Protocol) integration using langchain-mcp-adapters."""

from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache
from .client import build_server_params, build_servers_config
from .tools import get_mcp_tools

__all__ = [
    "build_server_params",
    "build_servers_config",
    "get_mcp_tools",
    "initialize_mcp_tools",
    "get_cached_mcp_tools",
    "reset_mcp_tools_cache",
]
backend/packages/harness/deerflow/mcp/cache.py (new file, 138 lines)
@@ -0,0 +1,138 @@
"""Cache for MCP tools to avoid repeated loading."""

import asyncio
import logging
import os

from langchain_core.tools import BaseTool

logger = logging.getLogger(__name__)

_mcp_tools_cache: list[BaseTool] | None = None
_cache_initialized = False
_initialization_lock = asyncio.Lock()
_config_mtime: float | None = None  # Track config file modification time


def _get_config_mtime() -> float | None:
    """Get the modification time of the extensions config file.

    Returns:
        The modification time as a float, or None if the file doesn't exist.
    """
    from deerflow.config.extensions_config import ExtensionsConfig

    config_path = ExtensionsConfig.resolve_config_path()
    if config_path and config_path.exists():
        return os.path.getmtime(config_path)
    return None


def _is_cache_stale() -> bool:
    """Check if the cache is stale due to config file changes.

    Returns:
        True if the cache should be invalidated, False otherwise.
    """
    global _config_mtime

    if not _cache_initialized:
        return False  # Not initialized yet, not stale

    current_mtime = _get_config_mtime()

    # If we couldn't get mtime before or now, assume not stale
    if _config_mtime is None or current_mtime is None:
        return False

    # If the config file has been modified since we cached, it's stale
    if current_mtime > _config_mtime:
        logger.info(f"MCP config file has been modified (mtime: {_config_mtime} -> {current_mtime}), cache is stale")
        return True

    return False


async def initialize_mcp_tools() -> list[BaseTool]:
    """Initialize and cache MCP tools.

    This should be called once at application startup.

    Returns:
        List of LangChain tools from all enabled MCP servers.
    """
    global _mcp_tools_cache, _cache_initialized, _config_mtime

    async with _initialization_lock:
        if _cache_initialized:
            logger.info("MCP tools already initialized")
            return _mcp_tools_cache or []

        from deerflow.mcp.tools import get_mcp_tools

        logger.info("Initializing MCP tools...")
        _mcp_tools_cache = await get_mcp_tools()
        _cache_initialized = True
        _config_mtime = _get_config_mtime()  # Record config file mtime
        logger.info(f"MCP tools initialized: {len(_mcp_tools_cache)} tool(s) loaded (config mtime: {_config_mtime})")

    return _mcp_tools_cache


def get_cached_mcp_tools() -> list[BaseTool]:
    """Get cached MCP tools with lazy initialization.

    If tools are not initialized, automatically initializes them.
    This ensures MCP tools work in both FastAPI and LangGraph Studio contexts.

    Also checks if the config file has been modified since last initialization,
    and re-initializes if needed. This ensures that changes made through the
    Gateway API (which runs in a separate process) are reflected in the
    LangGraph Server.

    Returns:
        List of cached MCP tools.
    """
    global _cache_initialized

    # Check if cache is stale due to config file changes
    if _is_cache_stale():
        logger.info("MCP cache is stale, resetting for re-initialization...")
        reset_mcp_tools_cache()

    if not _cache_initialized:
        logger.info("MCP tools not initialized, performing lazy initialization...")
        try:
            # Try to initialize in the current event loop
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # If loop is already running (e.g., in LangGraph Studio),
                # we need to create a new loop in a thread
                import concurrent.futures

                with concurrent.futures.ThreadPoolExecutor() as executor:
                    future = executor.submit(asyncio.run, initialize_mcp_tools())
                    future.result()
            else:
                # If no loop is running, we can use the current loop
                loop.run_until_complete(initialize_mcp_tools())
        except RuntimeError:
            # No event loop exists, create one
            asyncio.run(initialize_mcp_tools())
        except Exception as e:
            logger.error(f"Failed to lazy-initialize MCP tools: {e}")
            return []

    return _mcp_tools_cache or []


def reset_mcp_tools_cache() -> None:
    """Reset the MCP tools cache.

    This is useful for testing or when you want to reload MCP tools.
    """
    global _mcp_tools_cache, _cache_initialized, _config_mtime
    _mcp_tools_cache = None
    _cache_initialized = False
    _config_mtime = None
    logger.info("MCP tools cache reset")
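
The intended call pattern, sketched; the on-disk edit stands in for a Gateway API change made in another process:

tools = get_cached_mcp_tools()  # first call lazily runs initialize_mcp_tools()
# ... extensions_config.json is edited on disk, so its mtime advances ...
tools = get_cached_mcp_tools()  # _is_cache_stale() notices, resets, and re-initializes
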
backend/packages/harness/deerflow/mcp/client.py (new file, 68 lines)
@@ -0,0 +1,68 @@
"""MCP client using langchain-mcp-adapters."""

import logging
from typing import Any

from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig

logger = logging.getLogger(__name__)


def build_server_params(server_name: str, config: McpServerConfig) -> dict[str, Any]:
    """Build server parameters for MultiServerMCPClient.

    Args:
        server_name: Name of the MCP server.
        config: Configuration for the MCP server.

    Returns:
        Dictionary of server parameters for langchain-mcp-adapters.
    """
    transport_type = config.type or "stdio"
    params: dict[str, Any] = {"transport": transport_type}

    if transport_type == "stdio":
        if not config.command:
            raise ValueError(f"MCP server '{server_name}' with stdio transport requires 'command' field")
        params["command"] = config.command
        params["args"] = config.args
        # Add environment variables if present
        if config.env:
            params["env"] = config.env
    elif transport_type in ("sse", "http"):
        if not config.url:
            raise ValueError(f"MCP server '{server_name}' with {transport_type} transport requires 'url' field")
        params["url"] = config.url
        # Add headers if present
        if config.headers:
            params["headers"] = config.headers
    else:
        raise ValueError(f"MCP server '{server_name}' has unsupported transport type: {transport_type}")

    return params


def build_servers_config(extensions_config: ExtensionsConfig) -> dict[str, dict[str, Any]]:
    """Build servers configuration for MultiServerMCPClient.

    Args:
        extensions_config: Extensions configuration containing all MCP servers.

    Returns:
        Dictionary mapping server names to their parameters.
    """
    enabled_servers = extensions_config.get_enabled_mcp_servers()

    if not enabled_servers:
        logger.info("No enabled MCP servers found")
        return {}

    servers_config = {}
    for server_name, server_config in enabled_servers.items():
        try:
            servers_config[server_name] = build_server_params(server_name, server_config)
            logger.info(f"Configured MCP server: {server_name}")
        except Exception as e:
            logger.error(f"Failed to configure MCP server '{server_name}': {e}")

    return servers_config
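
For example (server name and command invented):

params = build_server_params(
    "filesystem",
    McpServerConfig(type="stdio", command="npx", args=["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]),
)
assert params == {
    "transport": "stdio",
    "command": "npx",
    "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
}
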
backend/packages/harness/deerflow/mcp/oauth.py (new file, 150 lines)
@@ -0,0 +1,150 @@
|
||||
"""OAuth token support for MCP HTTP/SSE servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpOAuthConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OAuthToken:
|
||||
"""Cached OAuth token."""
|
||||
|
||||
access_token: str
|
||||
token_type: str
|
||||
expires_at: datetime
|
||||
|
class OAuthTokenManager:
    """Acquire/cache/refresh OAuth tokens for MCP servers."""

    def __init__(self, oauth_by_server: dict[str, McpOAuthConfig]):
        self._oauth_by_server = oauth_by_server
        self._tokens: dict[str, _OAuthToken] = {}
        self._locks: dict[str, asyncio.Lock] = {name: asyncio.Lock() for name in oauth_by_server}

    @classmethod
    def from_extensions_config(cls, extensions_config: ExtensionsConfig) -> OAuthTokenManager:
        oauth_by_server: dict[str, McpOAuthConfig] = {}
        for server_name, server_config in extensions_config.get_enabled_mcp_servers().items():
            if server_config.oauth and server_config.oauth.enabled:
                oauth_by_server[server_name] = server_config.oauth
        return cls(oauth_by_server)

    def has_oauth_servers(self) -> bool:
        return bool(self._oauth_by_server)

    def oauth_server_names(self) -> list[str]:
        return list(self._oauth_by_server.keys())

    async def get_authorization_header(self, server_name: str) -> str | None:
        oauth = self._oauth_by_server.get(server_name)
        if not oauth:
            return None

        token = self._tokens.get(server_name)
        if token and not self._is_expiring(token, oauth):
            return f"{token.token_type} {token.access_token}"

        lock = self._locks[server_name]
        async with lock:
            token = self._tokens.get(server_name)
            if token and not self._is_expiring(token, oauth):
                return f"{token.token_type} {token.access_token}"

            fresh = await self._fetch_token(oauth)
            self._tokens[server_name] = fresh
            logger.info(f"Refreshed OAuth access token for MCP server: {server_name}")
            return f"{fresh.token_type} {fresh.access_token}"

    @staticmethod
    def _is_expiring(token: _OAuthToken, oauth: McpOAuthConfig) -> bool:
        now = datetime.now(UTC)
        return token.expires_at <= now + timedelta(seconds=max(oauth.refresh_skew_seconds, 0))

    async def _fetch_token(self, oauth: McpOAuthConfig) -> _OAuthToken:
        import httpx  # pyright: ignore[reportMissingImports]

        data: dict[str, str] = {
            "grant_type": oauth.grant_type,
            **oauth.extra_token_params,
        }

        if oauth.scope:
            data["scope"] = oauth.scope
        if oauth.audience:
            data["audience"] = oauth.audience

        if oauth.grant_type == "client_credentials":
            if not oauth.client_id or not oauth.client_secret:
                raise ValueError("OAuth client_credentials requires client_id and client_secret")
            data["client_id"] = oauth.client_id
            data["client_secret"] = oauth.client_secret
        elif oauth.grant_type == "refresh_token":
            if not oauth.refresh_token:
                raise ValueError("OAuth refresh_token grant requires refresh_token")
            data["refresh_token"] = oauth.refresh_token
            if oauth.client_id:
                data["client_id"] = oauth.client_id
            if oauth.client_secret:
                data["client_secret"] = oauth.client_secret
        else:
            raise ValueError(f"Unsupported OAuth grant type: {oauth.grant_type}")

        async with httpx.AsyncClient(timeout=15.0) as client:
            response = await client.post(oauth.token_url, data=data)
            response.raise_for_status()
            payload = response.json()

        access_token = payload.get(oauth.token_field)
        if not access_token:
            raise ValueError(f"OAuth token response missing '{oauth.token_field}'")

        token_type = str(payload.get(oauth.token_type_field, oauth.default_token_type) or oauth.default_token_type)

        expires_in_raw = payload.get(oauth.expires_in_field, 3600)
        try:
            expires_in = int(expires_in_raw)
        except (TypeError, ValueError):
            expires_in = 3600

        expires_at = datetime.now(UTC) + timedelta(seconds=max(expires_in, 1))
        return _OAuthToken(access_token=access_token, token_type=token_type, expires_at=expires_at)


def build_oauth_tool_interceptor(extensions_config: ExtensionsConfig) -> Any | None:
    """Build a tool interceptor that injects OAuth Authorization headers."""
    token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
    if not token_manager.has_oauth_servers():
        return None

    async def oauth_interceptor(request: Any, handler: Any) -> Any:
        header = await token_manager.get_authorization_header(request.server_name)
        if not header:
            return await handler(request)

        updated_headers = dict(request.headers or {})
        updated_headers["Authorization"] = header
        return await handler(request.override(headers=updated_headers))

    return oauth_interceptor


async def get_initial_oauth_headers(extensions_config: ExtensionsConfig) -> dict[str, str]:
    """Get initial OAuth Authorization headers for MCP server connections."""
    token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
    if not token_manager.has_oauth_servers():
        return {}

    headers: dict[str, str] = {}
    for server_name in token_manager.oauth_server_names():
        headers[server_name] = await token_manager.get_authorization_header(server_name) or ""

    return {name: value for name, value in headers.items() if value}
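
A minimal usage sketch, assuming an extensions config with at least one OAuth-enabled MCP server; the double-checked locking above guarantees concurrent callers trigger at most one token fetch per server:

# Hypothetical usage sketch (illustration only, not part of this PR).
import asyncio

from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.oauth import OAuthTokenManager


async def main() -> None:
    extensions_config = ExtensionsConfig.from_file()
    manager = OAuthTokenManager.from_extensions_config(extensions_config)
    if manager.has_oauth_servers():
        # Cached token is returned until it nears expiry (refresh_skew_seconds).
        header = await manager.get_authorization_header(manager.oauth_server_names()[0])
        print(header)  # e.g. "Bearer eyJ..."


asyncio.run(main())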
66
backend/packages/harness/deerflow/mcp/tools.py
Normal file
@@ -0,0 +1,66 @@
"""Load MCP tools using langchain-mcp-adapters."""

import logging

from langchain_core.tools import BaseTool

from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers

logger = logging.getLogger(__name__)


async def get_mcp_tools() -> list[BaseTool]:
    """Get all tools from enabled MCP servers.

    Returns:
        List of LangChain tools from all enabled MCP servers.
    """
    try:
        from langchain_mcp_adapters.client import MultiServerMCPClient
    except ImportError:
        logger.warning("langchain-mcp-adapters not installed. Install it to enable MCP tools: pip install langchain-mcp-adapters")
        return []

    # NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
    # to always read the latest configuration from disk. This ensures that changes
    # made through the Gateway API (which runs in a separate process) are immediately
    # reflected when initializing MCP tools.
    extensions_config = ExtensionsConfig.from_file()
    servers_config = build_servers_config(extensions_config)

    if not servers_config:
        logger.info("No enabled MCP servers configured")
        return []

    try:
        # Create the multi-server MCP client
        logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")

        # Inject initial OAuth headers for server connections (tool discovery/session init)
        initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
        for server_name, auth_header in initial_oauth_headers.items():
            if server_name not in servers_config:
                continue
            if servers_config[server_name].get("transport") in ("sse", "http"):
                existing_headers = dict(servers_config[server_name].get("headers", {}))
                existing_headers["Authorization"] = auth_header
                servers_config[server_name]["headers"] = existing_headers

        tool_interceptors = []
        oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
        if oauth_interceptor is not None:
            tool_interceptors.append(oauth_interceptor)

        client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors)

        # Get all tools from all servers
        tools = await client.get_tools()
        logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")

        return tools

    except Exception as e:
        logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
        return []
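
A minimal caller sketch; note that every failure mode above degrades to an empty tool list rather than raising:

# Hypothetical usage sketch (illustration only, not part of this PR).
import asyncio

from deerflow.mcp.tools import get_mcp_tools

tools = asyncio.run(get_mcp_tools())
print(f"{len(tools)} MCP tool(s) available")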
3
backend/packages/harness/deerflow/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .factory import create_chat_model

__all__ = ["create_chat_model"]
79
backend/packages/harness/deerflow/models/factory.py
Normal file
@@ -0,0 +1,79 @@
import logging

from langchain.chat_models import BaseChatModel

from deerflow.config import get_app_config, get_tracing_config, is_tracing_enabled
from deerflow.reflection import resolve_class

logger = logging.getLogger(__name__)


def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
    """Create a chat model instance from the config.

    Args:
        name: The name of the model to create. If None, the first model in the config will be used.
        thinking_enabled: Whether to enable the model's thinking/reasoning mode.
        **kwargs: Extra keyword arguments forwarded to the model constructor.

    Returns:
        A chat model instance.
    """
    config = get_app_config()
    if name is None:
        name = config.models[0].name
    model_config = config.get_model_config(name)
    if model_config is None:
        raise ValueError(f"Model {name} not found in config") from None
    model_class = resolve_class(model_config.use, BaseChatModel)
    model_settings_from_config = model_config.model_dump(
        exclude_none=True,
        exclude={
            "use",
            "name",
            "display_name",
            "description",
            "supports_thinking",
            "supports_reasoning_effort",
            "when_thinking_enabled",
            "thinking",
            "supports_vision",
        },
    )
    # Compute effective when_thinking_enabled by merging in the `thinking` shortcut field.
    # The `thinking` shortcut is equivalent to setting when_thinking_enabled["thinking"].
    has_thinking_settings = (model_config.when_thinking_enabled is not None) or (model_config.thinking is not None)
    effective_wte: dict = dict(model_config.when_thinking_enabled) if model_config.when_thinking_enabled else {}
    if model_config.thinking is not None:
        merged_thinking = {**(effective_wte.get("thinking") or {}), **model_config.thinking}
        effective_wte = {**effective_wte, "thinking": merged_thinking}
    if thinking_enabled and has_thinking_settings:
        if not model_config.supports_thinking:
            raise ValueError(f"Model {name} does not support thinking. Set `supports_thinking` to true in the `config.yaml` to enable thinking.") from None
        if effective_wte:
            model_settings_from_config.update(effective_wte)
    if not thinking_enabled and has_thinking_settings:
        if effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
            # OpenAI-compatible gateway: thinking is nested under extra_body
            kwargs.update({"extra_body": {"thinking": {"type": "disabled"}}})
            kwargs.update({"reasoning_effort": "minimal"})
        elif effective_wte.get("thinking", {}).get("type"):
            # Native langchain_anthropic: thinking is a direct constructor parameter
            kwargs.update({"thinking": {"type": "disabled"}})
        if not model_config.supports_reasoning_effort and "reasoning_effort" in kwargs:
            del kwargs["reasoning_effort"]

    model_instance = model_class(**kwargs, **model_settings_from_config)

    if is_tracing_enabled():
        try:
            from langchain_core.tracers.langchain import LangChainTracer

            tracing_config = get_tracing_config()
            tracer = LangChainTracer(
                project_name=tracing_config.project,
            )
            existing_callbacks = model_instance.callbacks or []
            model_instance.callbacks = [*existing_callbacks, tracer]
            logger.debug(f"LangSmith tracing attached to model '{name}' (project='{tracing_config.project}')")
        except Exception as e:
            logger.warning(f"Failed to attach LangSmith tracing to model '{name}': {e}")
    return model_instance
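
A minimal usage sketch; the model name "deepseek-chat" below is a placeholder and must match an entry under `models:` in your config.yaml:

# Hypothetical usage sketch (illustration only, not part of this PR).
from deerflow.models import create_chat_model

model = create_chat_model()  # first model listed in config.yaml
thinking_model = create_chat_model("deepseek-chat", thinking_enabled=True)
reply = model.invoke("Say hello in one word.")
print(reply.content)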
65
backend/packages/harness/deerflow/models/patched_deepseek.py
Normal file
@@ -0,0 +1,65 @@
"""Patched ChatDeepSeek that preserves reasoning_content in multi-turn conversations.

This module provides a patched version of ChatDeepSeek that properly handles
reasoning_content when sending messages back to the API. The original implementation
stores reasoning_content in additional_kwargs but doesn't include it when making
subsequent API calls, which causes errors with APIs that require reasoning_content
on all assistant messages when thinking mode is enabled.
"""

from typing import Any

from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_deepseek import ChatDeepSeek


class PatchedChatDeepSeek(ChatDeepSeek):
    """ChatDeepSeek with proper reasoning_content preservation.

    When using thinking/reasoning enabled models, the API expects reasoning_content
    to be present on ALL assistant messages in multi-turn conversations. This patched
    version ensures reasoning_content from additional_kwargs is included in the
    request payload.
    """

    def _get_request_payload(
        self,
        input_: LanguageModelInput,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict:
        """Get request payload with reasoning_content preserved.

        Overrides the parent method to inject reasoning_content from
        additional_kwargs into assistant messages in the payload.
        """
        # Get the original messages before conversion
        original_messages = self._convert_input(input_).to_messages()

        # Call parent to get the base payload
        payload = super()._get_request_payload(input_, stop=stop, **kwargs)

        # Match payload messages with original messages to restore reasoning_content
        payload_messages = payload.get("messages", [])

        # The payload messages and original messages should be in the same order
        # Iterate through both and match by position
        if len(payload_messages) == len(original_messages):
            for payload_msg, orig_msg in zip(payload_messages, original_messages):
                if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
                    reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
                    if reasoning_content is not None:
                        payload_msg["reasoning_content"] = reasoning_content
        else:
            # Fallback: match by counting assistant messages
            ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
            assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]

            for (idx, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
                reasoning_content = ai_msg.additional_kwargs.get("reasoning_content")
                if reasoning_content is not None:
                    payload_messages[idx]["reasoning_content"] = reasoning_content

        return payload
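
Since the subclass only overrides payload construction, it is a drop-in replacement; a minimal sketch, assuming standard ChatDeepSeek constructor arguments and placeholder credentials:

# Hypothetical usage sketch (illustration only, not part of this PR).
from deerflow.models.patched_deepseek import PatchedChatDeepSeek

model = PatchedChatDeepSeek(model="deepseek-reasoner", api_key="sk-...")  # placeholders
# In multi-turn use, assistant turns keep `reasoning_content` in the payload,
# which some thinking-enabled endpoints require on every assistant message.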
3
backend/packages/harness/deerflow/reflection/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .resolvers import resolve_class, resolve_variable

__all__ = ["resolve_class", "resolve_variable"]
95
backend/packages/harness/deerflow/reflection/resolvers.py
Normal file
@@ -0,0 +1,95 @@
from importlib import import_module

MODULE_TO_PACKAGE_HINTS = {
    "langchain_google_genai": "langchain-google-genai",
    "langchain_anthropic": "langchain-anthropic",
    "langchain_openai": "langchain-openai",
    "langchain_deepseek": "langchain-deepseek",
}


def _build_missing_dependency_hint(module_path: str, err: ImportError) -> str:
    """Build an actionable hint when module import fails."""
    module_root = module_path.split(".", 1)[0]
    missing_module = getattr(err, "name", None) or module_root

    # Prefer provider package hints for known integrations, even when the import
    # error is triggered by a transitive dependency (e.g. `google`).
    package_name = MODULE_TO_PACKAGE_HINTS.get(module_root)
    if package_name is None:
        package_name = MODULE_TO_PACKAGE_HINTS.get(missing_module, missing_module.replace("_", "-"))

    return f"Missing dependency '{missing_module}'. Install it with `uv add {package_name}` (or `pip install {package_name}`), then restart DeerFlow."


def resolve_variable[T](
    variable_path: str,
    expected_type: type[T] | tuple[type, ...] | None = None,
) -> T:
    """Resolve a variable from a path.

    Args:
        variable_path: The path to the variable (e.g. "parent_package_name.sub_package_name.module_name:variable_name").
        expected_type: Optional type or tuple of types to validate the resolved variable against.
            If provided, uses isinstance() to check if the variable is an instance of the expected type(s).

    Returns:
        The resolved variable.

    Raises:
        ImportError: If the module path is invalid or the attribute doesn't exist.
        ValueError: If the resolved variable doesn't pass the validation checks.
    """
    try:
        module_path, variable_name = variable_path.rsplit(":", 1)
    except ValueError as err:
        raise ImportError(f"{variable_path} doesn't look like a variable path. Example: parent_package_name.sub_package_name.module_name:variable_name") from err

    try:
        module = import_module(module_path)
    except ImportError as err:
        module_root = module_path.split(".", 1)[0]
        err_name = getattr(err, "name", None)
        if isinstance(err, ModuleNotFoundError) or err_name == module_root:
            hint = _build_missing_dependency_hint(module_path, err)
            raise ImportError(f"Could not import module {module_path}. {hint}") from err
        # Preserve the original ImportError message for non-missing-module failures.
        raise ImportError(f"Error importing module {module_path}: {err}") from err

    try:
        variable = getattr(module, variable_name)
    except AttributeError as err:
        raise ImportError(f"Module {module_path} does not define a {variable_name} attribute/class") from err

    # Type validation
    if expected_type is not None:
        if not isinstance(variable, expected_type):
            type_name = expected_type.__name__ if isinstance(expected_type, type) else " or ".join(t.__name__ for t in expected_type)
            raise ValueError(f"{variable_path} is not an instance of {type_name}, got {type(variable).__name__}")

    return variable


def resolve_class[T](class_path: str, base_class: type[T] | None = None) -> type[T]:
    """Resolve a class from a module path and class name.

    Args:
        class_path: The path to the class (e.g. "langchain_openai:ChatOpenAI").
        base_class: The base class to check if the resolved class is a subclass of.

    Returns:
        The resolved class.

    Raises:
        ImportError: If the module path is invalid or the attribute doesn't exist.
        ValueError: If the resolved object is not a class or not a subclass of base_class.
    """
    model_class = resolve_variable(class_path, expected_type=type)

    if not isinstance(model_class, type):
        raise ValueError(f"{class_path} is not a valid class")

    if base_class is not None and not issubclass(model_class, base_class):
        raise ValueError(f"{class_path} is not a subclass of {base_class.__name__}")

    return model_class
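
A minimal sketch of resolving a provider class from the "module:attr" path form used in config.yaml `use:` fields; the model name below is a placeholder:

# Hypothetical usage sketch (illustration only, not part of this PR).
from langchain.chat_models import BaseChatModel

from deerflow.reflection import resolve_class

chat_cls = resolve_class("langchain_openai:ChatOpenAI", BaseChatModel)
model = chat_cls(model="gpt-4o-mini")  # placeholder model name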
8
backend/packages/harness/deerflow/sandbox/__init__.py
Normal file
@@ -0,0 +1,8 @@
from .sandbox import Sandbox
from .sandbox_provider import SandboxProvider, get_sandbox_provider

__all__ = [
    "Sandbox",
    "SandboxProvider",
    "get_sandbox_provider",
]
71
backend/packages/harness/deerflow/sandbox/exceptions.py
Normal file
@@ -0,0 +1,71 @@
"""Sandbox-related exceptions with structured error information."""


class SandboxError(Exception):
    """Base exception for all sandbox-related errors."""

    def __init__(self, message: str, details: dict | None = None):
        super().__init__(message)
        self.message = message
        self.details = details or {}

    def __str__(self) -> str:
        if self.details:
            detail_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
            return f"{self.message} ({detail_str})"
        return self.message


class SandboxNotFoundError(SandboxError):
    """Raised when a sandbox cannot be found or is not available."""

    def __init__(self, message: str = "Sandbox not found", sandbox_id: str | None = None):
        details = {"sandbox_id": sandbox_id} if sandbox_id else None
        super().__init__(message, details)
        self.sandbox_id = sandbox_id


class SandboxRuntimeError(SandboxError):
    """Raised when sandbox runtime is not available or misconfigured."""

    pass


class SandboxCommandError(SandboxError):
    """Raised when a command execution fails in the sandbox."""

    def __init__(self, message: str, command: str | None = None, exit_code: int | None = None):
        details = {}
        if command:
            details["command"] = command[:100] + "..." if len(command) > 100 else command
        if exit_code is not None:
            details["exit_code"] = exit_code
        super().__init__(message, details)
        self.command = command
        self.exit_code = exit_code


class SandboxFileError(SandboxError):
    """Raised when a file operation fails in the sandbox."""

    def __init__(self, message: str, path: str | None = None, operation: str | None = None):
        details = {}
        if path:
            details["path"] = path
        if operation:
            details["operation"] = operation
        super().__init__(message, details)
        self.path = path
        self.operation = operation


class SandboxPermissionError(SandboxFileError):
    """Raised when a permission error occurs during file operations."""

    pass


class SandboxFileNotFoundError(SandboxFileError):
    """Raised when a file or directory is not found."""

    pass
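
A short sketch of how the structured details surface in the message, so callers can log a single readable line:

# Hypothetical usage sketch (illustration only, not part of this PR).
from deerflow.sandbox.exceptions import SandboxCommandError

try:
    raise SandboxCommandError("Command failed", command="ls /nope", exit_code=2)
except SandboxCommandError as e:
    print(e)  # Command failed (command=ls /nope, exit_code=2)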
3
backend/packages/harness/deerflow/sandbox/local/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .local_sandbox_provider import LocalSandboxProvider

__all__ = ["LocalSandboxProvider"]
112
backend/packages/harness/deerflow/sandbox/local/list_dir.py
Normal file
@@ -0,0 +1,112 @@
import fnmatch
from pathlib import Path

IGNORE_PATTERNS = [
    # Version Control
    ".git",
    ".svn",
    ".hg",
    ".bzr",
    # Dependencies
    "node_modules",
    "__pycache__",
    ".venv",
    "venv",
    ".env",
    "env",
    ".tox",
    ".nox",
    ".eggs",
    "*.egg-info",
    "site-packages",
    # Build outputs
    "dist",
    "build",
    ".next",
    ".nuxt",
    ".output",
    ".turbo",
    "target",
    "out",
    # IDE & Editor
    ".idea",
    ".vscode",
    "*.swp",
    "*.swo",
    "*~",
    ".project",
    ".classpath",
    ".settings",
    # OS generated
    ".DS_Store",
    "Thumbs.db",
    "desktop.ini",
    "*.lnk",
    # Logs & temp files
    "*.log",
    "*.tmp",
    "*.temp",
    "*.bak",
    "*.cache",
    ".cache",
    "logs",
    # Coverage & test artifacts
    ".coverage",
    "coverage",
    ".nyc_output",
    "htmlcov",
    ".pytest_cache",
    ".mypy_cache",
    ".ruff_cache",
]


def _should_ignore(name: str) -> bool:
    """Check if a file/directory name matches any ignore pattern."""
    for pattern in IGNORE_PATTERNS:
        if fnmatch.fnmatch(name, pattern):
            return True
    return False


def list_dir(path: str, max_depth: int = 2) -> list[str]:
    """
    List files and directories up to max_depth levels deep.

    Args:
        path: The root directory path to list.
        max_depth: Maximum depth to traverse (default: 2).
            1 = only direct children, 2 = children + grandchildren, etc.

    Returns:
        A list of absolute paths for files and directories,
        excluding items matching IGNORE_PATTERNS.
    """
    result: list[str] = []
    root_path = Path(path).resolve()

    if not root_path.is_dir():
        return result

    def _traverse(current_path: Path, current_depth: int) -> None:
        """Recursively traverse directories up to max_depth."""
        if current_depth > max_depth:
            return

        try:
            for item in current_path.iterdir():
                if _should_ignore(item.name):
                    continue

                post_fix = "/" if item.is_dir() else ""
                result.append(str(item.resolve()) + post_fix)

                # Recurse into subdirectories if not at max depth
                if item.is_dir() and current_depth < max_depth:
                    _traverse(item, current_depth + 1)
        except PermissionError:
            pass

    _traverse(root_path, 1)

    return sorted(result)
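
A minimal usage sketch; "/tmp/project" below is a placeholder path:

# Hypothetical usage sketch (illustration only, not part of this PR).
from deerflow.sandbox.local.list_dir import list_dir

for entry in list_dir("/tmp/project", max_depth=2):
    print(entry)  # directories carry a trailing "/"; ignore patterns are skipped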
212
backend/packages/harness/deerflow/sandbox/local/local_sandbox.py
Normal file
@@ -0,0 +1,212 @@
import os
import shutil
import subprocess
from pathlib import Path

from deerflow.sandbox.local.list_dir import list_dir
from deerflow.sandbox.sandbox import Sandbox


class LocalSandbox(Sandbox):
    def __init__(self, id: str, path_mappings: dict[str, str] | None = None):
        """
        Initialize local sandbox with optional path mappings.

        Args:
            id: Sandbox identifier
            path_mappings: Dictionary mapping container paths to local paths
                Example: {"/mnt/skills": "/absolute/path/to/skills"}
        """
        super().__init__(id)
        self.path_mappings = path_mappings or {}

    def _resolve_path(self, path: str) -> str:
        """
        Resolve container path to actual local path using mappings.

        Args:
            path: Path that might be a container path

        Returns:
            Resolved local path
        """
        path_str = str(path)

        # Try each mapping (longest prefix first for more specific matches)
        for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True):
            if path_str.startswith(container_path):
                # Replace the container path prefix with local path
                relative = path_str[len(container_path) :].lstrip("/")
                resolved = str(Path(local_path) / relative) if relative else local_path
                return resolved

        # No mapping found, return original path
        return path_str

    def _reverse_resolve_path(self, path: str) -> str:
        """
        Reverse resolve local path back to container path using mappings.

        Args:
            path: Local path that might need to be mapped to container path

        Returns:
            Container path if mapping exists, otherwise original path
        """
        path_str = str(Path(path).resolve())

        # Try each mapping (longest local path first for more specific matches)
        for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True):
            local_path_resolved = str(Path(local_path).resolve())
            if path_str.startswith(local_path_resolved):
                # Replace the local path prefix with container path
                relative = path_str[len(local_path_resolved) :].lstrip("/")
                resolved = f"{container_path}/{relative}" if relative else container_path
                return resolved

        # No mapping found, return original path
        return path_str

    def _reverse_resolve_paths_in_output(self, output: str) -> str:
        """
        Reverse resolve local paths back to container paths in output string.

        Args:
            output: Output string that may contain local paths

        Returns:
            Output with local paths resolved to container paths
        """
        import re

        # Sort mappings by local path length (longest first) for correct prefix matching
        sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True)

        if not sorted_mappings:
            return output

        # Create pattern that matches absolute paths
        # Match paths like /Users/... or other absolute paths
        result = output
        for container_path, local_path in sorted_mappings:
            local_path_resolved = str(Path(local_path).resolve())
            # Escape the local path for use in regex
            escaped_local = re.escape(local_path_resolved)
            # Match the local path followed by optional path components
            pattern = re.compile(escaped_local + r"(?:/[^\s\"';&|<>()]*)?")

            def replace_match(match: re.Match) -> str:
                matched_path = match.group(0)
                return self._reverse_resolve_path(matched_path)

            result = pattern.sub(replace_match, result)

        return result

    def _resolve_paths_in_command(self, command: str) -> str:
        """
        Resolve container paths to local paths in a command string.

        Args:
            command: Command string that may contain container paths

        Returns:
            Command with container paths resolved to local paths
        """
        import re

        # Sort mappings by length (longest first) for correct prefix matching
        sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True)

        # Build regex pattern to match all container paths
        # Match container path followed by optional path components
        if not sorted_mappings:
            return command

        # Create pattern that matches any of the container paths
        patterns = [re.escape(container_path) + r"(?:/[^\s\"';&|<>()]*)??" for container_path, _ in sorted_mappings]
        pattern = re.compile("|".join(f"({p})" for p in patterns))

        def replace_match(match: re.Match) -> str:
            matched_path = match.group(0)
            return self._resolve_path(matched_path)

        return pattern.sub(replace_match, command)

    @staticmethod
    def _get_shell() -> str:
        """Detect available shell executable with fallback.

        Returns the first available shell in order of preference:
        /bin/zsh → /bin/bash → /bin/sh → first `sh` found on PATH.
        Raises a RuntimeError if no suitable shell is found.
        """
        for shell in ("/bin/zsh", "/bin/bash", "/bin/sh"):
            if os.path.isfile(shell) and os.access(shell, os.X_OK):
                return shell
        shell_from_path = shutil.which("sh")
        if shell_from_path is not None:
            return shell_from_path
        raise RuntimeError("No suitable shell executable found. Tried /bin/zsh, /bin/bash, /bin/sh, and `sh` on PATH.")

    def execute_command(self, command: str) -> str:
        # Resolve container paths in command before execution
        resolved_command = self._resolve_paths_in_command(command)

        result = subprocess.run(
            resolved_command,
            executable=self._get_shell(),
            shell=True,
            capture_output=True,
            text=True,
            timeout=600,
        )
        output = result.stdout
        if result.stderr:
            output += f"\nStd Error:\n{result.stderr}" if output else result.stderr
        if result.returncode != 0:
            output += f"\nExit Code: {result.returncode}"

        final_output = output if output else "(no output)"
        # Reverse resolve local paths back to container paths in output
        return self._reverse_resolve_paths_in_output(final_output)

    def list_dir(self, path: str, max_depth=2) -> list[str]:
        resolved_path = self._resolve_path(path)
        entries = list_dir(resolved_path, max_depth)
        # Reverse resolve local paths back to container paths in output
        return [self._reverse_resolve_paths_in_output(entry) for entry in entries]

    def read_file(self, path: str) -> str:
        resolved_path = self._resolve_path(path)
        try:
            with open(resolved_path) as f:
                return f.read()
        except OSError as e:
            # Re-raise with the original path for clearer error messages, hiding internal resolved paths
            raise type(e)(e.errno, e.strerror, path) from None

    def write_file(self, path: str, content: str, append: bool = False) -> None:
        resolved_path = self._resolve_path(path)
        try:
            dir_path = os.path.dirname(resolved_path)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
            mode = "a" if append else "w"
            with open(resolved_path, mode) as f:
                f.write(content)
        except OSError as e:
            # Re-raise with the original path for clearer error messages, hiding internal resolved paths
            raise type(e)(e.errno, e.strerror, path) from None

    def update_file(self, path: str, content: bytes) -> None:
        resolved_path = self._resolve_path(path)
        try:
            dir_path = os.path.dirname(resolved_path)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
            with open(resolved_path, "wb") as f:
                f.write(content)
        except OSError as e:
            # Re-raise with the original path for clearer error messages, hiding internal resolved paths
            raise type(e)(e.errno, e.strerror, path) from None
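
A minimal sketch of the round-trip mapping; the mapping below is a placeholder: container paths in commands are rewritten to local paths before execution, and local paths in output are masked back:

# Hypothetical usage sketch (illustration only, not part of this PR).
from deerflow.sandbox.local.local_sandbox import LocalSandbox

sandbox = LocalSandbox("local", path_mappings={"/mnt/skills": "/home/me/deer-flow/skills"})
print(sandbox.execute_command("ls /mnt/skills"))  # actually lists /home/me/deer-flow/skills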
60
backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py
Normal file
@@ -0,0 +1,60 @@
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider

_singleton: LocalSandbox | None = None


class LocalSandboxProvider(SandboxProvider):
    def __init__(self):
        """Initialize the local sandbox provider with path mappings."""
        self._path_mappings = self._setup_path_mappings()

    def _setup_path_mappings(self) -> dict[str, str]:
        """
        Setup path mappings for local sandbox.

        Maps container paths to actual local paths, including skills directory.

        Returns:
            Dictionary of path mappings
        """
        mappings = {}

        # Map skills container path to local skills directory
        try:
            from deerflow.config import get_app_config

            config = get_app_config()
            skills_path = config.skills.get_skills_path()
            container_path = config.skills.container_path

            # Only add mapping if skills directory exists
            if skills_path.exists():
                mappings[container_path] = str(skills_path)
        except Exception as e:
            # Log but don't fail if config loading fails
            print(f"Warning: Could not setup skills path mapping: {e}")

        return mappings

    def acquire(self, thread_id: str | None = None) -> str:
        global _singleton
        if _singleton is None:
            _singleton = LocalSandbox("local", path_mappings=self._path_mappings)
        return _singleton.id

    def get(self, sandbox_id: str) -> Sandbox | None:
        if sandbox_id == "local":
            if _singleton is None:
                self.acquire()
            return _singleton
        return None

    def release(self, sandbox_id: str) -> None:
        # LocalSandbox uses singleton pattern - no cleanup needed.
        # Note: This method is intentionally not called by SandboxMiddleware
        # to allow sandbox reuse across multiple turns in a thread.
        # For Docker-based providers (e.g., AioSandboxProvider), cleanup
        # happens at application shutdown via the shutdown() method.
        pass
81
backend/packages/harness/deerflow/sandbox/middleware.py
Normal file
@@ -0,0 +1,81 @@
import logging
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.sandbox import get_sandbox_provider

logger = logging.getLogger(__name__)


class SandboxMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    sandbox: NotRequired[SandboxState | None]
    thread_data: NotRequired[ThreadDataState | None]


class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
    """Create a sandbox environment and assign it to an agent.

    Lifecycle Management:
    - With lazy_init=True (default): Sandbox is acquired on first tool call
    - With lazy_init=False: Sandbox is acquired on first agent invocation (before_agent)
    - Sandbox is reused across multiple turns within the same thread
    - Sandbox is NOT released after each agent call to avoid wasteful recreation
    - Cleanup happens at application shutdown via SandboxProvider.shutdown()
    """

    state_schema = SandboxMiddlewareState

    def __init__(self, lazy_init: bool = True):
        """Initialize sandbox middleware.

        Args:
            lazy_init: If True, defer sandbox acquisition until first tool call.
                If False, acquire sandbox eagerly in before_agent().
                Default is True for optimal performance.
        """
        super().__init__()
        self._lazy_init = lazy_init

    def _acquire_sandbox(self, thread_id: str) -> str:
        provider = get_sandbox_provider()
        sandbox_id = provider.acquire(thread_id)
        logger.info(f"Acquiring sandbox {sandbox_id}")
        return sandbox_id

    @override
    def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        # Skip acquisition if lazy_init is enabled
        if self._lazy_init:
            return super().before_agent(state, runtime)

        # Eager initialization (original behavior)
        if "sandbox" not in state or state["sandbox"] is None:
            thread_id = runtime.context["thread_id"]
            sandbox_id = self._acquire_sandbox(thread_id)
            logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
            return {"sandbox": {"sandbox_id": sandbox_id}}
        return super().before_agent(state, runtime)

    @override
    def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        sandbox = state.get("sandbox")
        if sandbox is not None:
            sandbox_id = sandbox["sandbox_id"]
            logger.info(f"Releasing sandbox {sandbox_id}")
            get_sandbox_provider().release(sandbox_id)
            return None

        if runtime.context.get("sandbox_id") is not None:
            sandbox_id = runtime.context.get("sandbox_id")
            logger.info(f"Releasing sandbox {sandbox_id} from context")
            get_sandbox_provider().release(sandbox_id)
            return None

        # No sandbox to release
        return super().after_agent(state, runtime)
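
A wiring sketch, assuming langchain's create_agent accepts an AgentMiddleware list via its middleware= parameter (not shown in this diff):

# Hypothetical wiring sketch (illustration only, not part of this PR).
from langchain.agents import create_agent

from deerflow.models import create_chat_model
from deerflow.sandbox.middleware import SandboxMiddleware

agent = create_agent(
    model=create_chat_model(),
    tools=[],
    middleware=[SandboxMiddleware(lazy_init=True)],
)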
72
backend/packages/harness/deerflow/sandbox/sandbox.py
Normal file
@@ -0,0 +1,72 @@
from abc import ABC, abstractmethod


class Sandbox(ABC):
    """Abstract base class for sandbox environments"""

    _id: str

    def __init__(self, id: str):
        self._id = id

    @property
    def id(self) -> str:
        return self._id

    @abstractmethod
    def execute_command(self, command: str) -> str:
        """Execute bash command in sandbox.

        Args:
            command: The command to execute.

        Returns:
            The standard or error output of the command.
        """
        pass

    @abstractmethod
    def read_file(self, path: str) -> str:
        """Read the content of a file.

        Args:
            path: The absolute path of the file to read.

        Returns:
            The content of the file.
        """
        pass

    @abstractmethod
    def list_dir(self, path: str, max_depth=2) -> list[str]:
        """List the contents of a directory.

        Args:
            path: The absolute path of the directory to list.
            max_depth: The maximum depth to traverse. Default is 2.

        Returns:
            The contents of the directory.
        """
        pass

    @abstractmethod
    def write_file(self, path: str, content: str, append: bool = False) -> None:
        """Write content to a file.

        Args:
            path: The absolute path of the file to write to.
            content: The text content to write to the file.
            append: Whether to append the content to the file. If False, the file will be created or overwritten.
        """
        pass

    @abstractmethod
    def update_file(self, path: str, content: bytes) -> None:
        """Update a file with binary content.

        Args:
            path: The absolute path of the file to update.
            content: The binary content to write to the file.
        """
        pass
96
backend/packages/harness/deerflow/sandbox/sandbox_provider.py
Normal file
@@ -0,0 +1,96 @@
from abc import ABC, abstractmethod

from deerflow.config import get_app_config
from deerflow.reflection import resolve_class
from deerflow.sandbox.sandbox import Sandbox


class SandboxProvider(ABC):
    """Abstract base class for sandbox providers"""

    @abstractmethod
    def acquire(self, thread_id: str | None = None) -> str:
        """Acquire a sandbox environment and return its ID.

        Returns:
            The ID of the acquired sandbox environment.
        """
        pass

    @abstractmethod
    def get(self, sandbox_id: str) -> Sandbox | None:
        """Get a sandbox environment by ID.

        Args:
            sandbox_id: The ID of the sandbox environment to retrieve.
        """
        pass

    @abstractmethod
    def release(self, sandbox_id: str) -> None:
        """Release a sandbox environment.

        Args:
            sandbox_id: The ID of the sandbox environment to release.
        """
        pass


_default_sandbox_provider: SandboxProvider | None = None


def get_sandbox_provider(**kwargs) -> SandboxProvider:
    """Get the sandbox provider singleton.

    Returns a cached singleton instance. Use `reset_sandbox_provider()` to clear
    the cache, or `shutdown_sandbox_provider()` to properly shut down and clear.

    Returns:
        A sandbox provider instance.
    """
    global _default_sandbox_provider
    if _default_sandbox_provider is None:
        config = get_app_config()
        cls = resolve_class(config.sandbox.use, SandboxProvider)
        _default_sandbox_provider = cls(**kwargs)
    return _default_sandbox_provider


def reset_sandbox_provider() -> None:
    """Reset the sandbox provider singleton.

    This clears the cached instance without calling shutdown.
    The next call to `get_sandbox_provider()` will create a new instance.
    Useful for testing or when switching configurations.

    Note: If the provider has active sandboxes, they will be orphaned.
    Use `shutdown_sandbox_provider()` for proper cleanup.
    """
    global _default_sandbox_provider
    _default_sandbox_provider = None


def shutdown_sandbox_provider() -> None:
    """Shutdown and reset the sandbox provider.

    This properly shuts down the provider (releasing all sandboxes)
    before clearing the singleton. Call this when the application
    is shutting down or when you need to completely reset the sandbox system.
    """
    global _default_sandbox_provider
    if _default_sandbox_provider is not None:
        if hasattr(_default_sandbox_provider, "shutdown"):
            _default_sandbox_provider.shutdown()
        _default_sandbox_provider = None


def set_sandbox_provider(provider: SandboxProvider) -> None:
    """Set a custom sandbox provider instance.

    This allows injecting a custom or mock provider for testing purposes.

    Args:
        provider: The SandboxProvider instance to use.
    """
    global _default_sandbox_provider
    _default_sandbox_provider = provider
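
A test-injection sketch showing how the singleton hooks fit together; no assumptions beyond the functions defined above:

# Hypothetical test sketch (illustration only, not part of this PR).
from deerflow.sandbox import get_sandbox_provider
from deerflow.sandbox.local import LocalSandboxProvider
from deerflow.sandbox.sandbox_provider import reset_sandbox_provider, set_sandbox_provider

set_sandbox_provider(LocalSandboxProvider())
assert isinstance(get_sandbox_provider(), LocalSandboxProvider)
reset_sandbox_provider()  # next get_sandbox_provider() rebuilds from config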
538
backend/packages/harness/deerflow/sandbox/tools.py
Normal file
@@ -0,0 +1,538 @@
import re
from pathlib import Path

from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT

from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
    SandboxError,
    SandboxNotFoundError,
    SandboxRuntimeError,
)
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider

_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])/(?:[^\s\"'`;&|<>()]+)")
_LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
    "/bin/",
    "/usr/bin/",
    "/usr/sbin/",
    "/sbin/",
    "/opt/homebrew/bin/",
    "/dev/",
)


def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
    """Replace virtual /mnt/user-data paths with actual thread data paths.

    Mapping:
        /mnt/user-data/workspace/* -> thread_data['workspace_path']/*
        /mnt/user-data/uploads/* -> thread_data['uploads_path']/*
        /mnt/user-data/outputs/* -> thread_data['outputs_path']/*

    Args:
        path: The path that may contain virtual path prefix.
        thread_data: The thread data containing actual paths.

    Returns:
        The path with virtual prefix replaced by actual path.
    """
    if thread_data is None:
        return path

    mappings = _thread_virtual_to_actual_mappings(thread_data)
    if not mappings:
        return path

    # Longest-prefix-first replacement with segment-boundary checks.
    for virtual_base, actual_base in sorted(mappings.items(), key=lambda item: len(item[0]), reverse=True):
        if path == virtual_base:
            return actual_base
        if path.startswith(f"{virtual_base}/"):
            rest = path[len(virtual_base) :].lstrip("/")
            return str(Path(actual_base) / rest) if rest else actual_base

    return path


def _thread_virtual_to_actual_mappings(thread_data: ThreadDataState) -> dict[str, str]:
    """Build virtual-to-actual path mappings for a thread."""
    mappings: dict[str, str] = {}

    workspace = thread_data.get("workspace_path")
    uploads = thread_data.get("uploads_path")
    outputs = thread_data.get("outputs_path")

    if workspace:
        mappings[f"{VIRTUAL_PATH_PREFIX}/workspace"] = workspace
    if uploads:
        mappings[f"{VIRTUAL_PATH_PREFIX}/uploads"] = uploads
    if outputs:
        mappings[f"{VIRTUAL_PATH_PREFIX}/outputs"] = outputs

    # Also map the virtual root when all known dirs share the same parent.
    actual_dirs = [Path(p) for p in (workspace, uploads, outputs) if p]
    if actual_dirs:
        common_parent = str(Path(actual_dirs[0]).parent)
        if all(str(path.parent) == common_parent for path in actual_dirs):
            mappings[VIRTUAL_PATH_PREFIX] = common_parent

    return mappings


def _thread_actual_to_virtual_mappings(thread_data: ThreadDataState) -> dict[str, str]:
    """Build actual-to-virtual mappings for output masking."""
    return {actual: virtual for virtual, actual in _thread_virtual_to_actual_mappings(thread_data).items()}


def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None) -> str:
    """Mask host absolute paths from local sandbox output using virtual paths."""
    if thread_data is None:
        return output

    mappings = _thread_actual_to_virtual_mappings(thread_data)
    if not mappings:
        return output

    result = output
    for actual_base, virtual_base in sorted(mappings.items(), key=lambda item: len(item[0]), reverse=True):
        raw_base = str(Path(actual_base))
        resolved_base = str(Path(actual_base).resolve())
        for base in {raw_base, resolved_base}:
            escaped_actual = re.escape(base)
            pattern = re.compile(escaped_actual + r"(?:/[^\s\"';&|<>()]*)?")

            def replace_match(match: re.Match) -> str:
                matched_path = match.group(0)
                if matched_path == base:
                    return virtual_base
                relative = matched_path[len(base) :].lstrip("/")
                return f"{virtual_base}/{relative}" if relative else virtual_base

            result = pattern.sub(replace_match, result)

    return result


def resolve_local_tool_path(path: str, thread_data: ThreadDataState | None) -> str:
    """Resolve and validate a local-sandbox tool path.

    Only virtual paths under /mnt/user-data are allowed in local mode.
    """
    if thread_data is None:
        raise SandboxRuntimeError("Thread data not available for local sandbox")

    if not path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
        raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/ are allowed")

    resolved_path = replace_virtual_path(path, thread_data)
    resolved = Path(resolved_path).resolve()

    allowed_roots = [
        Path(p).resolve()
        for p in (
            thread_data.get("workspace_path"),
            thread_data.get("uploads_path"),
            thread_data.get("outputs_path"),
        )
        if p is not None
    ]

    if not allowed_roots:
        raise SandboxRuntimeError("No allowed local sandbox directories configured")

    for root in allowed_roots:
        try:
            resolved.relative_to(root)
            return str(resolved)
        except ValueError:
            continue

    raise PermissionError("Access denied: path traversal detected")


def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None) -> None:
    """Validate absolute paths in local-sandbox bash commands.

    In local mode, commands must use virtual paths under /mnt/user-data for
    user data access. A small allowlist of common system path prefixes is kept
    for executable and device references (e.g. /bin/sh, /dev/null).
    """
    if thread_data is None:
        raise SandboxRuntimeError("Thread data not available for local sandbox")

    unsafe_paths: list[str] = []

    for absolute_path in _ABSOLUTE_PATH_PATTERN.findall(command):
        if absolute_path == VIRTUAL_PATH_PREFIX or absolute_path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
            continue

        if any(
            absolute_path == prefix.rstrip("/") or absolute_path.startswith(prefix)
            for prefix in _LOCAL_BASH_SYSTEM_PATH_PREFIXES
        ):
            continue

        unsafe_paths.append(absolute_path)

    if unsafe_paths:
        unsafe = ", ".join(sorted(dict.fromkeys(unsafe_paths)))
        raise PermissionError(f"Unsafe absolute paths in command: {unsafe}. Use paths under {VIRTUAL_PATH_PREFIX}")


def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None) -> str:
    """Replace all virtual /mnt/user-data paths in a command string.

    Args:
        command: The command string that may contain virtual paths.
        thread_data: The thread data containing actual paths.

    Returns:
        The command with all virtual paths replaced.
    """
    if VIRTUAL_PATH_PREFIX not in command:
        return command

    if thread_data is None:
        return command

    # Pattern to match /mnt/user-data followed by path characters
    pattern = re.compile(rf"{re.escape(VIRTUAL_PATH_PREFIX)}(/[^\s\"';&|<>()]*)?")

    def replace_match(match: re.Match) -> str:
        full_path = match.group(0)
        return replace_virtual_path(full_path, thread_data)

    return pattern.sub(replace_match, command)
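
An in-file illustration of the rewrite above; the literal dict stands in for a ThreadDataState with the three path fields, and the paths are placeholders:

# Hypothetical illustration (not part of this PR): virtual paths are rewritten
# via longest-prefix matching over the per-thread mappings.
thread_data = {
    "workspace_path": "/data/threads/t1/workspace",
    "uploads_path": "/data/threads/t1/uploads",
    "outputs_path": "/data/threads/t1/outputs",
}
cmd = replace_virtual_paths_in_command("cat /mnt/user-data/uploads/a.txt", thread_data)
# -> "cat /data/threads/t1/uploads/a.txt"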
|
||||
|
||||
def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> ThreadDataState | None:
|
||||
"""Extract thread_data from runtime state."""
|
||||
if runtime is None:
|
||||
return None
|
||||
if runtime.state is None:
|
||||
return None
|
||||
return runtime.state.get("thread_data")
|
||||
|
||||
|
||||
def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool:
|
||||
"""Check if the current sandbox is a local sandbox.
|
||||
|
||||
Path replacement is only needed for local sandbox since aio sandbox
|
||||
already has /mnt/user-data mounted in the container.
|
||||
"""
|
||||
if runtime is None:
|
||||
return False
|
||||
if runtime.state is None:
|
||||
return False
|
||||
sandbox_state = runtime.state.get("sandbox")
|
||||
if sandbox_state is None:
|
||||
return False
|
||||
return sandbox_state.get("sandbox_id") == "local"
|
||||
|
||||
|
||||
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
|
||||
"""Extract sandbox instance from tool runtime.
|
||||
|
||||
DEPRECATED: Use ensure_sandbox_initialized() for lazy initialization support.
|
||||
This function assumes sandbox is already initialized and will raise error if not.
|
||||
|
||||
Raises:
|
||||
SandboxRuntimeError: If runtime is not available or sandbox state is missing.
|
||||
SandboxNotFoundError: If sandbox with the given ID cannot be found.
|
||||
"""
|
||||
if runtime is None:
|
||||
raise SandboxRuntimeError("Tool runtime not available")
|
||||
if runtime.state is None:
|
||||
raise SandboxRuntimeError("Tool runtime state not available")
|
||||
sandbox_state = runtime.state.get("sandbox")
|
||||
if sandbox_state is None:
|
||||
raise SandboxRuntimeError("Sandbox state not initialized in runtime")
|
||||
sandbox_id = sandbox_state.get("sandbox_id")
|
||||
if sandbox_id is None:
|
||||
raise SandboxRuntimeError("Sandbox ID not found in state")
|
||||
sandbox = get_sandbox_provider().get(sandbox_id)
|
||||
if sandbox is None:
|
||||
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
|
||||
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
|
||||
return sandbox
|
||||
|
||||
|
||||
def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
|
||||
"""Ensure sandbox is initialized, acquiring lazily if needed.
|
||||
|
||||
On first call, acquires a sandbox from the provider and stores it in runtime state.
|
||||
Subsequent calls return the existing sandbox.
|
||||
|
||||
Thread-safety is guaranteed by the provider's internal locking mechanism.
|
||||
|
||||
Args:
|
||||
runtime: Tool runtime containing state and context.
|
||||
|
||||
Returns:
|
||||
Initialized sandbox instance.
|
||||
|
||||
Raises:
|
||||
SandboxRuntimeError: If runtime is not available or thread_id is missing.
|
||||
SandboxNotFoundError: If sandbox acquisition fails.
|
||||
"""
|
||||
if runtime is None:
|
||||
raise SandboxRuntimeError("Tool runtime not available")
|
||||
|
||||
if runtime.state is None:
|
||||
raise SandboxRuntimeError("Tool runtime state not available")
|
||||
|
||||
# Check if sandbox already exists in state
|
||||
sandbox_state = runtime.state.get("sandbox")
|
||||
if sandbox_state is not None:
|
||||
sandbox_id = sandbox_state.get("sandbox_id")
|
||||
if sandbox_id is not None:
|
||||
sandbox = get_sandbox_provider().get(sandbox_id)
|
||||
if sandbox is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
return sandbox
|
||||
# Sandbox was released, fall through to acquire new one
|
||||
|
||||
# Lazy acquisition: get thread_id and acquire sandbox
|
||||
thread_id = runtime.context.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise SandboxRuntimeError("Thread ID not available in runtime context")
|
||||
|
||||
provider = get_sandbox_provider()
|
||||
sandbox_id = provider.acquire(thread_id)
|
||||
|
||||
# Update runtime state - this persists across tool calls
|
||||
runtime.state["sandbox"] = {"sandbox_id": sandbox_id}
|
||||
|
||||
# Retrieve and return the sandbox
|
||||
sandbox = provider.get(sandbox_id)
|
||||
if sandbox is None:
|
||||
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
|
||||
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
return sandbox
|
||||
|
||||
|
||||
def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState] | None) -> None:
|
||||
"""Ensure thread data directories (workspace, uploads, outputs) exist.
|
||||
|
||||
This function is called lazily when any sandbox tool is first used.
|
||||
For local sandbox, it creates the directories on the filesystem.
|
||||
For other sandboxes (like aio), directories are already mounted in the container.
|
||||
|
||||
Args:
|
||||
runtime: Tool runtime containing state and context.
|
||||
"""
|
||||
if runtime is None:
|
||||
return
|
||||
|
||||
# Only create directories for local sandbox
|
||||
if not is_local_sandbox(runtime):
|
||||
return
|
||||
|
||||
thread_data = get_thread_data(runtime)
|
||||
if thread_data is None:
|
||||
return
|
||||
|
||||
# Check if directories have already been created
|
||||
if runtime.state.get("thread_directories_created"):
|
||||
return
|
||||
|
||||
# Create the three directories
|
||||
import os
|
||||
|
||||
for key in ["workspace_path", "uploads_path", "outputs_path"]:
|
||||
path = thread_data.get(key)
|
||||
if path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
# Mark as created to avoid redundant operations
|
||||
runtime.state["thread_directories_created"] = True
|
||||
|
||||
|
||||
@tool("bash", parse_docstring=True)
|
||||
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
|
||||
"""Execute a bash command in a Linux environment.
|
||||
|
||||
|
||||
- Use `python` to run Python code.
|
||||
- Prefer a thread-local virtual environment in `/mnt/user-data/workspace/.venv`.
|
||||
- Use `python -m pip` (inside the virtual environment) to install Python packages.
|
||||
|
||||
Args:
|
||||
description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
command: The bash command to execute. Always use absolute paths for files and directories.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
thread_data = get_thread_data(runtime)
|
||||
if is_local_sandbox(runtime):
|
||||
validate_local_bash_command_paths(command, thread_data)
|
||||
command = replace_virtual_paths_in_command(command, thread_data)
|
||||
output = sandbox.execute_command(command)
|
||||
return mask_local_paths_in_output(output, thread_data)
|
||||
return sandbox.execute_command(command)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error executing command: {type(e).__name__}: {e}"
|
||||
|
||||
|
||||
@tool("ls", parse_docstring=True)
|
||||
def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str:
|
||||
"""List the contents of a directory up to 2 levels deep in tree format.
|
||||
|
||||
Args:
|
||||
description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
path: The **absolute** path to the directory to list.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
path = resolve_local_tool_path(path, thread_data)
|
||||
children = sandbox.list_dir(path)
|
||||
if not children:
|
||||
return "(empty)"
|
||||
return "\n".join(children)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: Directory not found: {requested_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error listing directory: {type(e).__name__}: {e}"
|
||||
|
||||
|
||||
@tool("read_file", parse_docstring=True)
|
||||
def read_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
path: str,
|
||||
start_line: int | None = None,
|
||||
end_line: int | None = None,
|
||||
) -> str:
|
||||
"""Read the contents of a text file. Use this to examine source code, configuration files, logs, or any text-based file.
|
||||
|
||||
Args:
|
||||
description: Explain why you are reading this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
path: The **absolute** path to the file to read.
|
||||
start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range.
|
||||
end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
path = resolve_local_tool_path(path, thread_data)
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "(empty)"
|
||||
if start_line is not None and end_line is not None:
|
||||
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
|
||||
return content
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: File not found: {requested_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied reading file: {requested_path}"
|
||||
except IsADirectoryError:
|
||||
return f"Error: Path is a directory, not a file: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error reading file: {type(e).__name__}: {e}"
|
||||
|
||||
|
||||
@tool("write_file", parse_docstring=True)
|
||||
def write_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
path: str,
|
||||
content: str,
|
||||
append: bool = False,
|
||||
) -> str:
|
||||
"""Write text content to a file.
|
||||
|
||||
Args:
|
||||
description: Explain why you are writing to this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
|
||||
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
path = resolve_local_tool_path(path, thread_data)
|
||||
sandbox.write_file(path, content, append)
|
||||
return "OK"
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied writing to file: {requested_path}"
|
||||
except IsADirectoryError:
|
||||
return f"Error: Path is a directory, not a file: {requested_path}"
|
||||
except OSError as e:
|
||||
return f"Error: Failed to write file '{requested_path}': {e}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error writing file: {type(e).__name__}: {e}"
|
||||
|
||||
|
||||
@tool("str_replace", parse_docstring=True)
|
||||
def str_replace_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
path: str,
|
||||
old_str: str,
|
||||
new_str: str,
|
||||
replace_all: bool = False,
|
||||
) -> str:
|
||||
"""Replace a substring in a file with another substring.
|
||||
If `replace_all` is False (default), the substring to replace must appear **exactly once** in the file.
|
||||
|
||||
Args:
|
||||
description: Explain why you are replacing the substring in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
path: The **absolute** path to the file to replace the substring in. ALWAYS PROVIDE THIS PARAMETER SECOND.
|
||||
old_str: The substring to replace. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||
new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH.
|
||||
replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
path = resolve_local_tool_path(path, thread_data)
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "OK"
|
||||
if old_str not in content:
|
||||
return f"Error: String to replace not found in file: {requested_path}"
|
||||
if replace_all:
|
||||
content = content.replace(old_str, new_str)
|
||||
else:
|
||||
content = content.replace(old_str, new_str, 1)
|
||||
sandbox.write_file(path, content)
|
||||
return "OK"
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: File not found: {requested_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied accessing file: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error replacing string: {type(e).__name__}: {e}"
|
||||
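
All five tools above share the same failure contract: exceptions are converted into `Error: ...` strings so the model sees failures as ordinary tool output instead of a crashed run. For the local sandbox, `bash_tool` additionally round-trips paths through the helpers it calls. A rough sketch of that flow, assuming a `thread_data` mapping populated with the thread's `workspace_path`/`uploads_path`/`outputs_path` keys (the command and file name here are illustrative):

command = "cat /mnt/user-data/uploads/notes.txt"
validate_local_bash_command_paths(command, thread_data)       # reject paths outside the thread's directories
host_command = replace_virtual_paths_in_command(command, thread_data)
output = sandbox.execute_command(host_command)
print(mask_local_paths_in_output(output, thread_data))        # host paths rendered back as /mnt/user-data/...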
5
backend/packages/harness/deerflow/skills/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .loader import get_skills_root_path, load_skills
from .types import Skill
from .validation import ALLOWED_FRONTMATTER_PROPERTIES, _validate_skill_frontmatter

__all__ = ["load_skills", "get_skills_root_path", "Skill", "ALLOWED_FRONTMATTER_PROPERTIES", "_validate_skill_frontmatter"]
98
backend/packages/harness/deerflow/skills/loader.py
Normal file
@@ -0,0 +1,98 @@
import os
from pathlib import Path

from .parser import parse_skill_file
from .types import Skill


def get_skills_root_path() -> Path:
    """
    Get the root path of the skills directory.

    Returns:
        Path to the skills directory (deer-flow/skills)
    """
    # loader.py lives at packages/harness/deerflow/skills/loader.py — 5 parents up reaches backend/
    backend_dir = Path(__file__).resolve().parent.parent.parent.parent.parent
    # skills directory is sibling to backend directory
    skills_dir = backend_dir.parent / "skills"
    return skills_dir


def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
    """
    Load all skills from the skills directory.

    Scans both public and custom skill directories, parsing SKILL.md files
    to extract metadata. The enabled state is determined by the skills_state_config.json file.

    Args:
        skills_path: Optional custom path to skills directory.
            If not provided and use_config is True, uses path from config.
            Otherwise defaults to deer-flow/skills.
        use_config: Whether to load skills path from config (default: True)
        enabled_only: If True, only return enabled skills (default: False)

    Returns:
        List of Skill objects, sorted by name
    """
    if skills_path is None:
        if use_config:
            try:
                from deerflow.config import get_app_config

                config = get_app_config()
                skills_path = config.skills.get_skills_path()
            except Exception:
                # Fallback to default if config fails
                skills_path = get_skills_root_path()
        else:
            skills_path = get_skills_root_path()

    if not skills_path.exists():
        return []

    skills = []

    # Scan public and custom directories
    for category in ["public", "custom"]:
        category_path = skills_path / category
        if not category_path.exists() or not category_path.is_dir():
            continue

        for current_root, dir_names, file_names in os.walk(category_path):
            # Keep traversal deterministic and skip hidden directories.
            dir_names[:] = sorted(name for name in dir_names if not name.startswith("."))
            if "SKILL.md" not in file_names:
                continue

            skill_file = Path(current_root) / "SKILL.md"
            relative_path = skill_file.parent.relative_to(category_path)

            skill = parse_skill_file(skill_file, category=category, relative_path=relative_path)
            if skill:
                skills.append(skill)

    # Load skills state configuration and update enabled status.
    # NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
    # to always read the latest configuration from disk. This ensures that changes
    # made through the Gateway API (which runs in a separate process) are immediately
    # reflected in the LangGraph Server when loading skills.
    try:
        from deerflow.config.extensions_config import ExtensionsConfig

        extensions_config = ExtensionsConfig.from_file()
        for skill in skills:
            skill.enabled = extensions_config.is_skill_enabled(skill.name, skill.category)
    except Exception as e:
        # If config loading fails, default to all enabled
        print(f"Warning: Failed to load extensions config: {e}")

    # Filter by enabled status if requested
    if enabled_only:
        skills = [skill for skill in skills if skill.enabled]

    # Sort by name for consistent ordering
    skills.sort(key=lambda s: s.name)

    return skills
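
For reference, callers normally go through the package's public entry point; a minimal usage sketch, assuming the default `deer-flow/skills` layout with `public/` and `custom/` subdirectories:

from deerflow.skills import load_skills

for skill in load_skills(enabled_only=True):
    print(skill.name, skill.category, skill.skill_path)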
65
backend/packages/harness/deerflow/skills/parser.py
Normal file
@@ -0,0 +1,65 @@
import re
from pathlib import Path

from .types import Skill


def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None = None) -> Skill | None:
    """
    Parse a SKILL.md file and extract metadata.

    Args:
        skill_file: Path to the SKILL.md file
        category: Category of the skill ('public' or 'custom')
        relative_path: Path from the category root to the skill directory.
            Defaults to the skill directory's own name if not provided.

    Returns:
        Skill object if parsing succeeds, None otherwise
    """
    if not skill_file.exists() or skill_file.name != "SKILL.md":
        return None

    try:
        content = skill_file.read_text(encoding="utf-8")

        # Extract YAML front matter
        # Pattern: ---\nkey: value\n---
        front_matter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)

        if not front_matter_match:
            return None

        front_matter = front_matter_match.group(1)

        # Parse YAML front matter (simple key-value parsing)
        metadata = {}
        for line in front_matter.split("\n"):
            line = line.strip()
            if not line:
                continue
            if ":" in line:
                key, value = line.split(":", 1)
                metadata[key.strip()] = value.strip()

        # Extract required fields
        name = metadata.get("name")
        description = metadata.get("description")

        if not name or not description:
            return None

        license_text = metadata.get("license")

        return Skill(
            name=name,
            description=description,
            license=license_text,
            skill_dir=skill_file.parent,
            skill_file=skill_file,
            relative_path=relative_path or Path(skill_file.parent.name),
            category=category,
            enabled=True,  # Default to enabled, actual state comes from config file
        )

    except Exception as e:
        print(f"Error parsing skill file {skill_file}: {e}")
        return None
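
Note that the key-value loop above only handles flat, single-line frontmatter entries; nested YAML needs a real parser, which `validation.py` below uses. A minimal SKILL.md that this function accepts (the name and description values are made up for illustration):

---
name: pdf-report
description: Generate PDF reports from markdown input
license: MIT
---

Skill instructions in markdown go here.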
53
backend/packages/harness/deerflow/skills/types.py
Normal file
@@ -0,0 +1,53 @@
from dataclasses import dataclass
from pathlib import Path


@dataclass
class Skill:
    """Represents a skill with its metadata and file path"""

    name: str
    description: str
    license: str | None
    skill_dir: Path
    skill_file: Path
    relative_path: Path  # Relative path from category root to skill directory
    category: str  # 'public' or 'custom'
    enabled: bool = False  # Whether this skill is enabled

    @property
    def skill_path(self) -> str:
        """Returns the relative path from the category root (skills/{category}) to this skill's directory"""
        path = self.relative_path.as_posix()
        return "" if path == "." else path

    def get_container_path(self, container_base_path: str = "/mnt/skills") -> str:
        """
        Get the full path to this skill in the container.

        Args:
            container_base_path: Base path where skills are mounted in the container

        Returns:
            Full container path to the skill directory
        """
        category_base = f"{container_base_path}/{self.category}"
        skill_path = self.skill_path
        if skill_path:
            return f"{category_base}/{skill_path}"
        return category_base

    def get_container_file_path(self, container_base_path: str = "/mnt/skills") -> str:
        """
        Get the full path to this skill's main file (SKILL.md) in the container.

        Args:
            container_base_path: Base path where skills are mounted in the container

        Returns:
            Full container path to the skill's SKILL.md file
        """
        return f"{self.get_container_path(container_base_path)}/SKILL.md"

    def __repr__(self) -> str:
        return f"Skill(name={self.name!r}, description={self.description!r}, category={self.category!r})"
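
As a concrete reading of the path logic, assume a skill parsed from `skills/public/reporting/pdf-report/SKILL.md` (directory names illustrative), so `relative_path` is `reporting/pdf-report`:

skill.skill_path                 # "reporting/pdf-report"
skill.get_container_path()       # "/mnt/skills/public/reporting/pdf-report"
skill.get_container_file_path()  # "/mnt/skills/public/reporting/pdf-report/SKILL.md"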
85
backend/packages/harness/deerflow/skills/validation.py
Normal file
@@ -0,0 +1,85 @@
"""Skill frontmatter validation utilities.

Pure-logic validation of SKILL.md frontmatter — no FastAPI or HTTP dependencies.
"""

import re
from pathlib import Path

import yaml

# Allowed properties in SKILL.md frontmatter
ALLOWED_FRONTMATTER_PROPERTIES = {"name", "description", "license", "allowed-tools", "metadata", "compatibility", "version", "author"}


def _validate_skill_frontmatter(skill_dir: Path) -> tuple[bool, str, str | None]:
    """Validate a skill directory's SKILL.md frontmatter.

    Args:
        skill_dir: Path to the skill directory containing SKILL.md.

    Returns:
        Tuple of (is_valid, message, skill_name).
    """
    skill_md = skill_dir / "SKILL.md"
    if not skill_md.exists():
        return False, "SKILL.md not found", None

    content = skill_md.read_text()
    if not content.startswith("---"):
        return False, "No YAML frontmatter found", None

    # Extract frontmatter
    match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
    if not match:
        return False, "Invalid frontmatter format", None

    frontmatter_text = match.group(1)

    # Parse YAML frontmatter
    try:
        frontmatter = yaml.safe_load(frontmatter_text)
        if not isinstance(frontmatter, dict):
            return False, "Frontmatter must be a YAML dictionary", None
    except yaml.YAMLError as e:
        return False, f"Invalid YAML in frontmatter: {e}", None

    # Check for unexpected properties
    unexpected_keys = set(frontmatter.keys()) - ALLOWED_FRONTMATTER_PROPERTIES
    if unexpected_keys:
        return False, f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}", None

    # Check required fields
    if "name" not in frontmatter:
        return False, "Missing 'name' in frontmatter", None
    if "description" not in frontmatter:
        return False, "Missing 'description' in frontmatter", None

    # Validate name
    name = frontmatter.get("name", "")
    if not isinstance(name, str):
        return False, f"Name must be a string, got {type(name).__name__}", None
    name = name.strip()
    if not name:
        return False, "Name cannot be empty", None

    # Check naming convention (hyphen-case: lowercase with hyphens)
    if not re.match(r"^[a-z0-9-]+$", name):
        return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)", None
    if name.startswith("-") or name.endswith("-") or "--" in name:
        return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens", None
    if len(name) > 64:
        return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters.", None

    # Validate description
    description = frontmatter.get("description", "")
    if not isinstance(description, str):
        return False, f"Description must be a string, got {type(description).__name__}", None
    description = description.strip()
    if description:
        if "<" in description or ">" in description:
            return False, "Description cannot contain angle brackets (< or >)", None
        if len(description) > 1024:
            return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters.", None

    return True, "Skill is valid!", name
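
Because the validator returns a plain tuple instead of raising, call sites stay simple; a short sketch with a hypothetical skill directory:

from pathlib import Path

from deerflow.skills.validation import _validate_skill_frontmatter

is_valid, message, skill_name = _validate_skill_frontmatter(Path("skills/custom/my-skill"))
if not is_valid:
    print(f"Rejected: {message}")  # e.g. "Name 'My Skill' should be hyphen-case ..."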
11
backend/packages/harness/deerflow/subagents/__init__.py
Normal file
@@ -0,0 +1,11 @@
from .config import SubagentConfig
from .executor import SubagentExecutor, SubagentResult
from .registry import get_subagent_config, list_subagents

__all__ = [
    "SubagentConfig",
    "SubagentExecutor",
    "SubagentResult",
    "get_subagent_config",
    "list_subagents",
]
15
backend/packages/harness/deerflow/subagents/builtins/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""Built-in subagent configurations."""

from .bash_agent import BASH_AGENT_CONFIG
from .general_purpose import GENERAL_PURPOSE_CONFIG

__all__ = [
    "GENERAL_PURPOSE_CONFIG",
    "BASH_AGENT_CONFIG",
]

# Registry of built-in subagents
BUILTIN_SUBAGENTS = {
    "general-purpose": GENERAL_PURPOSE_CONFIG,
    "bash": BASH_AGENT_CONFIG,
}
46
backend/packages/harness/deerflow/subagents/builtins/bash_agent.py
Normal file
@@ -0,0 +1,46 @@
"""Bash command execution subagent configuration."""

from deerflow.subagents.config import SubagentConfig

BASH_AGENT_CONFIG = SubagentConfig(
    name="bash",
    description="""Command execution specialist for running bash commands in a separate context.

Use this subagent when:
- You need to run a series of related bash commands
- Terminal operations like git, npm, docker, etc.
- Command output is verbose and would clutter main context
- Build, test, or deployment operations

Do NOT use for simple single commands - use bash tool directly instead.""",
    system_prompt="""You are a bash command execution specialist. Execute the requested commands carefully and report results clearly.

<guidelines>
- Execute commands one at a time when they depend on each other
- Use parallel execution when commands are independent
- Report both stdout and stderr when relevant
- Handle errors gracefully and explain what went wrong
- Use absolute paths for file operations
- Be cautious with destructive operations (rm, overwrite, etc.)
</guidelines>

<output_format>
For each command or group of commands:
1. What was executed
2. The result (success/failure)
3. Relevant output (summarized if verbose)
4. Any errors or warnings
</output_format>

<working_directory>
You have access to the sandbox environment:
- User uploads: `/mnt/user-data/uploads`
- User workspace: `/mnt/user-data/workspace`
- Output files: `/mnt/user-data/outputs`
</working_directory>
""",
    tools=["bash", "ls", "read_file", "write_file", "str_replace"],  # Sandbox tools only
    disallowed_tools=["task", "ask_clarification", "present_files"],
    model="inherit",
    max_turns=30,
)
47
backend/packages/harness/deerflow/subagents/builtins/general_purpose.py
Normal file
@@ -0,0 +1,47 @@
"""General-purpose subagent configuration."""

from deerflow.subagents.config import SubagentConfig

GENERAL_PURPOSE_CONFIG = SubagentConfig(
    name="general-purpose",
    description="""A capable agent for complex, multi-step tasks that require both exploration and action.

Use this subagent when:
- The task requires both exploration and modification
- Complex reasoning is needed to interpret results
- Multiple dependent steps must be executed
- The task would benefit from isolated context management

Do NOT use for simple, single-step operations.""",
    system_prompt="""You are a general-purpose subagent working on a delegated task. Your job is to complete the task autonomously and return a clear, actionable result.

<guidelines>
- Focus on completing the delegated task efficiently
- Use available tools as needed to accomplish the goal
- Think step by step but act decisively
- If you encounter issues, explain them clearly in your response
- Return a concise summary of what you accomplished
- Do NOT ask for clarification - work with the information provided
</guidelines>

<output_format>
When you complete the task, provide:
1. A brief summary of what was accomplished
2. Key findings or results
3. Any relevant file paths, data, or artifacts created
4. Issues encountered (if any)
5. Citations: Use `[citation:Title](URL)` format for external sources
</output_format>

<working_directory>
You have access to the same sandbox environment as the parent agent:
- User uploads: `/mnt/user-data/uploads`
- User workspace: `/mnt/user-data/workspace`
- Output files: `/mnt/user-data/outputs`
</working_directory>
""",
    tools=None,  # Inherit all tools from parent
    disallowed_tools=["task", "ask_clarification", "present_files"],  # Prevent nesting and clarification
    model="inherit",
    max_turns=50,
)
28
backend/packages/harness/deerflow/subagents/config.py
Normal file
@@ -0,0 +1,28 @@
"""Subagent configuration definitions."""

from dataclasses import dataclass, field


@dataclass
class SubagentConfig:
    """Configuration for a subagent.

    Attributes:
        name: Unique identifier for the subagent.
        description: When the lead agent should delegate to this subagent.
        system_prompt: The system prompt that guides the subagent's behavior.
        tools: Optional list of tool names to allow. If None, inherits all tools.
        disallowed_tools: Optional list of tool names to deny.
        model: Model to use - 'inherit' uses parent's model.
        max_turns: Maximum number of agent turns before stopping.
        timeout_seconds: Maximum execution time in seconds (default: 900 = 15 minutes).
    """

    name: str
    description: str
    system_prompt: str
    tools: list[str] | None = None
    disallowed_tools: list[str] | None = field(default_factory=lambda: ["task"])
    model: str = "inherit"
    max_turns: int = 50
    timeout_seconds: int = 900
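
Since `SubagentConfig` is a plain dataclass, defining a new subagent is just instantiation; a hedged sketch of a hypothetical read-only reviewer (not one of the shipped builtins):

REVIEWER_CONFIG = SubagentConfig(
    name="reviewer",
    description="Code review specialist for isolated review passes.",
    system_prompt="You are a code review specialist. Read the code and report issues.",
    tools=["ls", "read_file"],   # read-only tool subset
    disallowed_tools=["task"],   # prevent nested delegation
    max_turns=20,                # timeout_seconds keeps its 900s default
)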
486
backend/packages/harness/deerflow/subagents/executor.py
Normal file
@@ -0,0 +1,486 @@
"""Subagent execution engine."""

import asyncio
import logging
import threading
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any

from langchain.agents import create_agent
from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig

from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
from deerflow.models import create_chat_model
from deerflow.subagents.config import SubagentConfig

logger = logging.getLogger(__name__)


class SubagentStatus(Enum):
    """Status of a subagent execution."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMED_OUT = "timed_out"


@dataclass
class SubagentResult:
    """Result of a subagent execution.

    Attributes:
        task_id: Unique identifier for this execution.
        trace_id: Trace ID for distributed tracing (links parent and subagent logs).
        status: Current status of the execution.
        result: The final result message (if completed).
        error: Error message (if failed).
        started_at: When execution started.
        completed_at: When execution completed.
        ai_messages: List of complete AI messages (as dicts) generated during execution.
    """

    task_id: str
    trace_id: str
    status: SubagentStatus
    result: str | None = None
    error: str | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None
    ai_messages: list[dict[str, Any]] | None = None

    def __post_init__(self):
        """Initialize mutable defaults."""
        if self.ai_messages is None:
            self.ai_messages = []


# Global storage for background task results
_background_tasks: dict[str, SubagentResult] = {}
_background_tasks_lock = threading.Lock()

# Thread pool for background task scheduling and orchestration
_scheduler_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-scheduler-")

# Thread pool for actual subagent execution (with timeout support).
# Sized to match the scheduler pool so every scheduled task can run without queueing.
_execution_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-exec-")

def _filter_tools(
    all_tools: list[BaseTool],
    allowed: list[str] | None,
    disallowed: list[str] | None,
) -> list[BaseTool]:
    """Filter tools based on subagent configuration.

    Args:
        all_tools: List of all available tools.
        allowed: Optional allowlist of tool names. If provided, only these tools are included.
        disallowed: Optional denylist of tool names. These tools are always excluded.

    Returns:
        Filtered list of tools.
    """
    filtered = all_tools

    # Apply allowlist if specified
    if allowed is not None:
        allowed_set = set(allowed)
        filtered = [t for t in filtered if t.name in allowed_set]

    # Apply denylist
    if disallowed is not None:
        disallowed_set = set(disallowed)
        filtered = [t for t in filtered if t.name not in disallowed_set]

    return filtered


def _get_model_name(config: SubagentConfig, parent_model: str | None) -> str | None:
    """Resolve the model name for a subagent.

    Args:
        config: Subagent configuration.
        parent_model: The parent agent's model name.

    Returns:
        Model name to use, or None to use the default.
    """
    if config.model == "inherit":
        return parent_model
    return config.model

class SubagentExecutor:
    """Executor for running subagents."""

    def __init__(
        self,
        config: SubagentConfig,
        tools: list[BaseTool],
        parent_model: str | None = None,
        sandbox_state: SandboxState | None = None,
        thread_data: ThreadDataState | None = None,
        thread_id: str | None = None,
        trace_id: str | None = None,
    ):
        """Initialize the executor.

        Args:
            config: Subagent configuration.
            tools: List of all available tools (will be filtered).
            parent_model: The parent agent's model name for inheritance.
            sandbox_state: Sandbox state from parent agent.
            thread_data: Thread data from parent agent.
            thread_id: Thread ID for sandbox operations.
            trace_id: Trace ID from parent for distributed tracing.
        """
        self.config = config
        self.parent_model = parent_model
        self.sandbox_state = sandbox_state
        self.thread_data = thread_data
        self.thread_id = thread_id
        # Generate trace_id if not provided (for top-level calls)
        self.trace_id = trace_id or str(uuid.uuid4())[:8]

        # Filter tools based on config
        self.tools = _filter_tools(
            tools,
            config.tools,
            config.disallowed_tools,
        )

        logger.info(f"[trace={self.trace_id}] SubagentExecutor initialized: {config.name} with {len(self.tools)} tools")

    def _create_agent(self):
        """Create the agent instance."""
        model_name = _get_model_name(self.config, self.parent_model)
        model = create_chat_model(name=model_name, thinking_enabled=False)

        from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares

        # Reuse shared middleware composition with lead agent.
        middlewares = build_subagent_runtime_middlewares(lazy_init=True)

        return create_agent(
            model=model,
            tools=self.tools,
            middleware=middlewares,
            system_prompt=self.config.system_prompt,
            state_schema=ThreadState,
        )

    def _build_initial_state(self, task: str) -> dict[str, Any]:
        """Build the initial state for agent execution.

        Args:
            task: The task description.

        Returns:
            Initial state dictionary.
        """
        state: dict[str, Any] = {
            "messages": [HumanMessage(content=task)],
        }

        # Pass through sandbox and thread data from parent
        if self.sandbox_state is not None:
            state["sandbox"] = self.sandbox_state
        if self.thread_data is not None:
            state["thread_data"] = self.thread_data

        return state

    async def _aexecute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
        """Execute a task asynchronously.

        Args:
            task: The task description for the subagent.
            result_holder: Optional pre-created result object to update during execution.

        Returns:
            SubagentResult with the execution result.
        """
        if result_holder is not None:
            # Use the provided result holder (for async execution with real-time updates)
            result = result_holder
        else:
            # Create a new result for synchronous execution
            task_id = str(uuid.uuid4())[:8]
            result = SubagentResult(
                task_id=task_id,
                trace_id=self.trace_id,
                status=SubagentStatus.RUNNING,
                started_at=datetime.now(),
            )

        try:
            agent = self._create_agent()
            state = self._build_initial_state(task)

            # Build config with thread_id for sandbox access and recursion limit
            run_config: RunnableConfig = {
                "recursion_limit": self.config.max_turns,
            }
            context = {}
            if self.thread_id:
                run_config["configurable"] = {"thread_id": self.thread_id}
                context["thread_id"] = self.thread_id

            logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}")

            # Use stream instead of invoke to get real-time updates.
            # This allows us to collect AI messages as they are generated.
            final_state = None
            async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"):  # type: ignore[arg-type]
                final_state = chunk

                # Extract AI messages from the current state
                messages = chunk.get("messages", [])
                if messages:
                    last_message = messages[-1]
                    # Check if this is a new AI message
                    if isinstance(last_message, AIMessage):
                        # Convert message to dict for serialization
                        message_dict = last_message.model_dump()
                        # Only add if it's not already in the list (avoid duplicates).
                        # Check by comparing message IDs if available, otherwise compare the full dict.
                        message_id = message_dict.get("id")
                        is_duplicate = False
                        if message_id:
                            is_duplicate = any(msg.get("id") == message_id for msg in result.ai_messages)
                        else:
                            is_duplicate = message_dict in result.ai_messages

                        if not is_duplicate:
                            result.ai_messages.append(message_dict)
                            logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}")

            logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")

            if final_state is None:
                logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
                result.result = "No response generated"
            else:
                # Extract the final message - find the last AIMessage
                messages = final_state.get("messages", [])
                logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}")

                # Find the last AIMessage in the conversation
                last_ai_message = None
                for msg in reversed(messages):
                    if isinstance(msg, AIMessage):
                        last_ai_message = msg
                        break

                if last_ai_message is not None:
                    content = last_ai_message.content
                    # Handle both str and list content types for the final result
                    if isinstance(content, str):
                        result.result = content
                    elif isinstance(content, list):
                        # Extract text from list of content blocks for final result only
                        text_parts = []
                        for block in content:
                            if isinstance(block, str):
                                text_parts.append(block)
                            elif isinstance(block, dict) and "text" in block:
                                text_parts.append(block["text"])
                        result.result = "\n".join(text_parts) if text_parts else "No text content in response"
                    else:
                        result.result = str(content)
                elif messages:
                    # Fallback: use the last message if no AIMessage found
                    last_message = messages[-1]
                    logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
                    result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message)
                else:
                    logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
                    result.result = "No response generated"

            result.status = SubagentStatus.COMPLETED
            result.completed_at = datetime.now()

        except Exception as e:
            logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
            result.status = SubagentStatus.FAILED
            result.error = str(e)
            result.completed_at = datetime.now()

        return result

    def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
        """Execute a task synchronously (wrapper around async execution).

        This method runs the async execution in a new event loop, allowing
        asynchronous tools (like MCP tools) to be used within the thread pool.

        Args:
            task: The task description for the subagent.
            result_holder: Optional pre-created result object to update during execution.

        Returns:
            SubagentResult with the execution result.
        """
        # Run the async execution in a new event loop.
        # This is necessary because:
        # 1. We may have async-only tools (like MCP tools)
        # 2. We're running inside a ThreadPoolExecutor which doesn't have an event loop
        #
        # Note: _aexecute() catches all exceptions internally, so this outer
        # try-except only handles asyncio.run() failures (e.g., if called from
        # an async context where an event loop already exists). Subagent execution
        # errors are handled within _aexecute() and returned as FAILED status.
        try:
            return asyncio.run(self._aexecute(task, result_holder))
        except Exception as e:
            logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} execution failed")
            # Create a result with the error if we don't have one
            if result_holder is not None:
                result = result_holder
            else:
                result = SubagentResult(
                    task_id=str(uuid.uuid4())[:8],
                    trace_id=self.trace_id,
                    status=SubagentStatus.FAILED,
                )
            result.status = SubagentStatus.FAILED
            result.error = str(e)
            result.completed_at = datetime.now()
            return result

    def execute_async(self, task: str, task_id: str | None = None) -> str:
        """Start a task execution in the background.

        Args:
            task: The task description for the subagent.
            task_id: Optional task ID to use. If not provided, a random UUID will be generated.

        Returns:
            Task ID that can be used to check status later.
        """
        # Use provided task_id or generate a new one
        if task_id is None:
            task_id = str(uuid.uuid4())[:8]

        # Create initial pending result
        result = SubagentResult(
            task_id=task_id,
            trace_id=self.trace_id,
            status=SubagentStatus.PENDING,
        )

        logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution, task_id={task_id}, timeout={self.config.timeout_seconds}s")

        with _background_tasks_lock:
            _background_tasks[task_id] = result

        # Submit to scheduler pool
        def run_task():
            with _background_tasks_lock:
                _background_tasks[task_id].status = SubagentStatus.RUNNING
                _background_tasks[task_id].started_at = datetime.now()
                result_holder = _background_tasks[task_id]

            try:
                # Submit execution to the execution pool with a timeout.
                # Pass result_holder so execute() can update it in real time.
                execution_future: Future = _execution_pool.submit(self.execute, task, result_holder)
                try:
                    # Wait for execution with timeout
                    exec_result = execution_future.result(timeout=self.config.timeout_seconds)
                    with _background_tasks_lock:
                        _background_tasks[task_id].status = exec_result.status
                        _background_tasks[task_id].result = exec_result.result
                        _background_tasks[task_id].error = exec_result.error
                        _background_tasks[task_id].completed_at = datetime.now()
                        _background_tasks[task_id].ai_messages = exec_result.ai_messages
                except FuturesTimeoutError:
                    logger.error(f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s")
                    with _background_tasks_lock:
                        _background_tasks[task_id].status = SubagentStatus.TIMED_OUT
                        _background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds"
                        _background_tasks[task_id].completed_at = datetime.now()
                    # Cancel the future (best effort - may not stop the actual execution)
                    execution_future.cancel()
            except Exception as e:
                logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
                with _background_tasks_lock:
                    _background_tasks[task_id].status = SubagentStatus.FAILED
                    _background_tasks[task_id].error = str(e)
                    _background_tasks[task_id].completed_at = datetime.now()

        _scheduler_pool.submit(run_task)
        return task_id

MAX_CONCURRENT_SUBAGENTS = 3


def get_background_task_result(task_id: str) -> SubagentResult | None:
    """Get the result of a background task.

    Args:
        task_id: The task ID returned by execute_async.

    Returns:
        SubagentResult if found, None otherwise.
    """
    with _background_tasks_lock:
        return _background_tasks.get(task_id)


def list_background_tasks() -> list[SubagentResult]:
    """List all background tasks.

    Returns:
        List of all SubagentResult instances.
    """
    with _background_tasks_lock:
        return list(_background_tasks.values())


def cleanup_background_task(task_id: str) -> None:
    """Remove a completed task from background tasks.

    Should be called by task_tool after it finishes polling and returns the result.
    This prevents memory leaks from accumulated completed tasks.

    Only removes tasks that are in a terminal state (COMPLETED/FAILED/TIMED_OUT)
    to avoid race conditions with the background executor still updating the task entry.

    Args:
        task_id: The task ID to remove.
    """
    with _background_tasks_lock:
        result = _background_tasks.get(task_id)
        if result is None:
            # Nothing to clean up; may have been removed already.
            logger.debug("Requested cleanup for unknown background task %s", task_id)
            return

        # Only clean up tasks that are in a terminal state to avoid races with
        # the background executor still updating the task entry.
        is_terminal_status = result.status in {
            SubagentStatus.COMPLETED,
            SubagentStatus.FAILED,
            SubagentStatus.TIMED_OUT,
        }
        if is_terminal_status or result.completed_at is not None:
            del _background_tasks[task_id]
            logger.debug("Cleaned up background task: %s", task_id)
        else:
            logger.debug(
                "Skipping cleanup for non-terminal background task %s (status=%s)",
                task_id,
                result.status.value if hasattr(result.status, "value") else result.status,
            )
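
The intended background flow pairs `execute_async` with polling and explicit cleanup. A sketch, assuming an already-constructed `executor`; the one-second polling cadence is illustrative:

import time

task_id = executor.execute_async("Summarize the repository layout")
while True:
    snapshot = get_background_task_result(task_id)
    if snapshot and snapshot.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.TIMED_OUT}:
        break
    time.sleep(1)
print(snapshot.result or snapshot.error)
cleanup_background_task(task_id)  # safe here: the task is in a terminal state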
52
backend/packages/harness/deerflow/subagents/registry.py
Normal file
@@ -0,0 +1,52 @@
"""Subagent registry for managing available subagents."""

import logging
from dataclasses import replace

from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.config import SubagentConfig

logger = logging.getLogger(__name__)


def get_subagent_config(name: str) -> SubagentConfig | None:
    """Get a subagent configuration by name, with config.yaml overrides applied.

    Args:
        name: The name of the subagent.

    Returns:
        SubagentConfig if found (with any config.yaml overrides applied), None otherwise.
    """
    config = BUILTIN_SUBAGENTS.get(name)
    if config is None:
        return None

    # Apply timeout override from config.yaml (lazy import to avoid circular deps)
    from deerflow.config.subagents_config import get_subagents_app_config

    app_config = get_subagents_app_config()
    effective_timeout = app_config.get_timeout_for(name)
    if effective_timeout != config.timeout_seconds:
        logger.debug(f"Subagent '{name}': timeout overridden by config.yaml ({config.timeout_seconds}s -> {effective_timeout}s)")
        config = replace(config, timeout_seconds=effective_timeout)

    return config


def list_subagents() -> list[SubagentConfig]:
    """List all available subagent configurations (with config.yaml overrides applied).

    Returns:
        List of all registered SubagentConfig instances.
    """
    return [get_subagent_config(name) for name in BUILTIN_SUBAGENTS]


def get_subagent_names() -> list[str]:
    """Get all available subagent names.

    Returns:
        List of subagent names.
    """
    return list(BUILTIN_SUBAGENTS.keys())
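
A config.yaml timeout override is invisible to callers, who always get the effective value; a brief sketch:

config = get_subagent_config("bash")
if config is not None:
    print(config.timeout_seconds)  # 900 by default, or the config.yaml override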
3
backend/packages/harness/deerflow/tools/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .tools import get_available_tools

__all__ = ["get_available_tools"]
13
backend/packages/harness/deerflow/tools/builtins/__init__.py
Normal file
@@ -0,0 +1,13 @@
from .clarification_tool import ask_clarification_tool
from .present_file_tool import present_file_tool
from .setup_agent_tool import setup_agent
from .task_tool import task_tool
from .view_image_tool import view_image_tool

__all__ = [
    "setup_agent",
    "present_file_tool",
    "ask_clarification_tool",
    "view_image_tool",
    "task_tool",
]
55
backend/packages/harness/deerflow/tools/builtins/clarification_tool.py
Normal file
@@ -0,0 +1,55 @@
from typing import Literal

from langchain.tools import tool


@tool("ask_clarification", parse_docstring=True, return_direct=True)
def ask_clarification_tool(
    question: str,
    clarification_type: Literal[
        "missing_info",
        "ambiguous_requirement",
        "approach_choice",
        "risk_confirmation",
        "suggestion",
    ],
    context: str | None = None,
    options: list[str] | None = None,
) -> str:
    """Ask the user for clarification when you need more information to proceed.

    Use this tool when you encounter situations where you cannot proceed without user input:

    - **Missing information**: Required details not provided (e.g., file paths, URLs, specific requirements)
    - **Ambiguous requirements**: Multiple valid interpretations exist
    - **Approach choices**: Several valid approaches exist and you need user preference
    - **Risky operations**: Destructive actions that need explicit confirmation (e.g., deleting files, modifying production)
    - **Suggestions**: You have a recommendation but want user approval before proceeding

    The execution will be interrupted and the question will be presented to the user.
    Wait for the user's response before continuing.

    When to use ask_clarification:
    - You need information that wasn't provided in the user's request
    - The requirement can be interpreted in multiple ways
    - Multiple valid implementation approaches exist
    - You're about to perform a potentially dangerous operation
    - You have a recommendation but need user approval

    Best practices:
    - Ask ONE clarification at a time for clarity
    - Be specific and clear in your question
    - Don't make assumptions when clarification is needed
    - For risky operations, ALWAYS ask for confirmation
    - After calling this tool, execution will be interrupted automatically

    Args:
        question: The clarification question to ask the user. Be specific and clear.
        clarification_type: The type of clarification needed (missing_info, ambiguous_requirement, approach_choice, risk_confirmation, suggestion).
        context: Optional context explaining why clarification is needed. Helps the user understand the situation.
        options: Optional list of choices (for approach_choice or suggestion types). Present clear options for the user to choose from.
    """
    # This is a placeholder implementation.
    # The actual logic is handled by ClarificationMiddleware, which intercepts this tool call
    # and interrupts execution to present the question to the user.
    return "Clarification request processed by middleware"
100
backend/packages/harness/deerflow/tools/builtins/present_file_tool.py
Normal file
@@ -0,0 +1,100 @@
from pathlib import Path
from typing import Annotated

from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT

from deerflow.agents.thread_state import ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths

OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"


def _normalize_presented_filepath(
    runtime: ToolRuntime[ContextT, ThreadState],
    filepath: str,
) -> str:
    """Normalize a presented file path to the `/mnt/user-data/outputs/*` contract.

    Accepts either:
    - A virtual sandbox path such as `/mnt/user-data/outputs/report.md`
    - A host-side thread outputs path such as
      `/app/backend/.deer-flow/threads/<thread>/user-data/outputs/report.md`

    Returns:
        The normalized virtual path.

    Raises:
        ValueError: If runtime metadata is missing or the path is outside the
            current thread's outputs directory.
    """
    if runtime.state is None:
        raise ValueError("Thread runtime state is not available")

    thread_id = runtime.context.get("thread_id")
    if not thread_id:
        raise ValueError("Thread ID is not available in runtime context")

    thread_data = runtime.state.get("thread_data") or {}
    outputs_path = thread_data.get("outputs_path")
    if not outputs_path:
        raise ValueError("Thread outputs path is not available in runtime state")

    outputs_dir = Path(outputs_path).resolve()
    stripped = filepath.lstrip("/")
    virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")

    if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
        actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
    else:
        actual_path = Path(filepath).expanduser().resolve()

    try:
        relative_path = actual_path.relative_to(outputs_dir)
    except ValueError as exc:
        raise ValueError(f"Only files in {OUTPUTS_VIRTUAL_PREFIX} can be presented: {filepath}") from exc

    return f"{OUTPUTS_VIRTUAL_PREFIX}/{relative_path.as_posix()}"


@tool("present_files", parse_docstring=True)
def present_file_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    filepaths: list[str],
    tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
    """Make files visible to the user for viewing and rendering in the client interface.

    When to use the present_files tool:
    - Making any file available for the user to view, download, or interact with
    - Presenting multiple related files at once
    - After creating files that should be presented to the user

    When NOT to use the present_files tool:
    - When you only need to read file contents for your own processing
    - For temporary or intermediate files not meant for user viewing

    Notes:
    - You should call this tool after creating files and moving them to the `/mnt/user-data/outputs` directory.
    - This tool can be safely called in parallel with other tools. State updates are handled by a reducer to prevent conflicts.

    Args:
        filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented.
    """
    try:
        normalized_paths = [_normalize_presented_filepath(runtime, filepath) for filepath in filepaths]
    except ValueError as exc:
        return Command(
            update={"messages": [ToolMessage(f"Error: {exc}", tool_call_id=tool_call_id)]},
        )

    # The merge_artifacts reducer will handle merging and deduplication
    return Command(
        update={
            "artifacts": normalized_paths,
            "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)],
        },
    )
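
Both accepted path shapes normalize to the same virtual form. A sketch of the contract, assuming a hypothetical thread whose `outputs_path` is `/app/backend/.deer-flow/threads/t1/user-data/outputs`:

# "/mnt/user-data/outputs/report.md"                               -> "/mnt/user-data/outputs/report.md"
# "/app/backend/.deer-flow/threads/t1/user-data/outputs/report.md" -> "/mnt/user-data/outputs/report.md"
# "/mnt/user-data/workspace/report.md"                             -> ValueError (outside the outputs directory)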
@@ -0,0 +1,62 @@
import logging

import yaml
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolRuntime
from langgraph.types import Command

from deerflow.config.paths import get_paths

logger = logging.getLogger(__name__)


@tool
def setup_agent(
    soul: str,
    description: str,
    runtime: ToolRuntime,
) -> Command:
    """Set up the custom DeerFlow agent.

    Args:
        soul: Full SOUL.md content defining the agent's personality and behavior.
        description: One-line description of what the agent does.
    """
    agent_name: str | None = runtime.context.get("agent_name")

    try:
        paths = get_paths()
        agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
        agent_dir.mkdir(parents=True, exist_ok=True)

        if agent_name:
            # If agent_name is provided, we are creating a custom agent in the agents/ directory
            config_data: dict = {"name": agent_name}
            if description:
                config_data["description"] = description

            config_file = agent_dir / "config.yaml"
            with open(config_file, "w", encoding="utf-8") as f:
                yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)

        soul_file = agent_dir / "SOUL.md"
        soul_file.write_text(soul, encoding="utf-8")

        logger.info(f"[agent_creator] Created agent '{agent_name}' at {agent_dir}")
        return Command(
            update={
                "created_agent_name": agent_name,
                "messages": [ToolMessage(content=f"Agent '{agent_name}' created successfully!", tool_call_id=runtime.tool_call_id)],
            }
        )

    except Exception as e:
        import shutil

        if agent_name and agent_dir.exists():
            # Cleanup the custom agent directory only if it was created but an error occurred during setup
            shutil.rmtree(agent_dir)
        logger.error(f"[agent_creator] Failed to create agent '{agent_name}': {e}", exc_info=True)
        return Command(update={"messages": [ToolMessage(content=f"Error: {e}", tool_call_id=runtime.tool_call_id)]})
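For orientation, a small sketch of the on-disk layout setup_agent produces for a custom agent. The agent name, description, and directory here are hypothetical, and a temp directory stands in for get_paths().agent_dir(...):

# Hypothetical sketch: reproduces the config.yaml + SOUL.md layout written by
# setup_agent, using a temp dir in place of get_paths().agent_dir(...).
import tempfile
from pathlib import Path

import yaml

agent_dir = Path(tempfile.mkdtemp()) / "agents" / "researcher"
agent_dir.mkdir(parents=True, exist_ok=True)

config_data = {"name": "researcher", "description": "Literature-review agent"}
(agent_dir / "config.yaml").write_text(
    yaml.dump(config_data, default_flow_style=False, allow_unicode=True),
    encoding="utf-8",
)
(agent_dir / "SOUL.md").write_text("Be thorough; cite sources.", encoding="utf-8")

print(sorted(p.name for p in agent_dir.iterdir()))  # ['SOUL.md', 'config.yaml']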
backend/packages/harness/deerflow/tools/builtins/task_tool.py (Normal file, 195 lines added)
@@ -0,0 +1,195 @@
"""Task tool for delegating work to subagents."""

import logging
import time
import uuid
from dataclasses import replace
from typing import Annotated, Literal

from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langgraph.config import get_stream_writer
from langgraph.typing import ContextT

from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
from deerflow.agents.thread_state import ThreadState
from deerflow.subagents import SubagentExecutor, get_subagent_config
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result

logger = logging.getLogger(__name__)


@tool("task", parse_docstring=True)
def task_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    description: str,
    prompt: str,
    subagent_type: Literal["general-purpose", "bash"],
    tool_call_id: Annotated[str, InjectedToolCallId],
    max_turns: int | None = None,
) -> str:
    """Delegate a task to a specialized subagent that runs in its own context.

    Subagents help you:
    - Preserve context by keeping exploration and implementation separate
    - Handle complex multi-step tasks autonomously
    - Execute commands or operations in isolated contexts

    Available subagent types:
    - **general-purpose**: A capable agent for complex, multi-step tasks that require
      both exploration and action. Use when the task requires complex reasoning,
      multiple dependent steps, or would benefit from isolated context.
    - **bash**: Command execution specialist for running bash commands. Use for
      git operations, build processes, or when command output would be verbose.

    When to use this tool:
    - Complex tasks requiring multiple steps or tools
    - Tasks that produce verbose output
    - When you want to isolate context from the main conversation
    - Parallel research or exploration tasks

    When NOT to use this tool:
    - Simple, single-step operations (use tools directly)
    - Tasks requiring user interaction or clarification

    Args:
        description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST.
        prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
        subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
        max_turns: Optional maximum number of agent turns. Defaults to the subagent's configured max.
    """
    # Get subagent configuration
    config = get_subagent_config(subagent_type)
    if config is None:
        return f"Error: Unknown subagent type '{subagent_type}'. Available: general-purpose, bash"

    # Build config overrides
    overrides: dict = {}

    skills_section = get_skills_prompt_section()
    if skills_section:
        overrides["system_prompt"] = config.system_prompt + "\n\n" + skills_section

    if max_turns is not None:
        overrides["max_turns"] = max_turns

    if overrides:
        config = replace(config, **overrides)

    # Extract parent context from runtime
    sandbox_state = None
    thread_data = None
    thread_id = None
    parent_model = None
    trace_id = None

    if runtime is not None:
        sandbox_state = runtime.state.get("sandbox")
        thread_data = runtime.state.get("thread_data")
        thread_id = runtime.context.get("thread_id")

        # Try to get the parent model from configurable metadata
        metadata = runtime.config.get("metadata", {})
        parent_model = metadata.get("model_name")

        # Get or generate trace_id for distributed tracing
        trace_id = metadata.get("trace_id") or str(uuid.uuid4())[:8]

    # Get available tools (excluding the task tool to prevent nesting)
    # Lazy import to avoid circular dependency
    from deerflow.tools import get_available_tools

    # Subagents should not have subagent tools enabled (prevents recursive nesting)
    tools = get_available_tools(model_name=parent_model, subagent_enabled=False)

    # Create executor
    executor = SubagentExecutor(
        config=config,
        tools=tools,
        parent_model=parent_model,
        sandbox_state=sandbox_state,
        thread_data=thread_data,
        thread_id=thread_id,
        trace_id=trace_id,
    )

    # Start background execution (always async to prevent blocking)
    # Use tool_call_id as task_id for better traceability
    task_id = executor.execute_async(prompt, task_id=tool_call_id)

    # Poll for task completion in the backend (removes the need for the LLM to poll)
    poll_count = 0
    last_status = None
    last_message_count = 0  # Track how many AI messages we've already sent
    # Polling timeout: execution timeout + 60s buffer, checked every 5s
    max_poll_count = (config.timeout_seconds + 60) // 5

    logger.info(f"[trace={trace_id}] Started background task {task_id} (subagent={subagent_type}, timeout={config.timeout_seconds}s, polling_limit={max_poll_count} polls)")

    writer = get_stream_writer()
    # Send the "task started" event
    writer({"type": "task_started", "task_id": task_id, "description": description})

    while True:
        result = get_background_task_result(task_id)

        if result is None:
            logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks")
            writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"})
            cleanup_background_task(task_id)
            return f"Error: Task {task_id} disappeared from background tasks"

        # Log status changes for debugging
        if result.status != last_status:
            logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
            last_status = result.status

        # Check for new AI messages and send task_running events
        current_message_count = len(result.ai_messages)
        if current_message_count > last_message_count:
            # Send a task_running event for each new message
            for i in range(last_message_count, current_message_count):
                message = result.ai_messages[i]
                writer(
                    {
                        "type": "task_running",
                        "task_id": task_id,
                        "message": message,
                        "message_index": i + 1,  # 1-based index for display
                        "total_messages": current_message_count,
                    }
                )
                logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
            last_message_count = current_message_count

        # Check if the task completed, failed, or timed out
        if result.status == SubagentStatus.COMPLETED:
            writer({"type": "task_completed", "task_id": task_id, "result": result.result})
            logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
            cleanup_background_task(task_id)
            return f"Task succeeded. Result: {result.result}"
        elif result.status == SubagentStatus.FAILED:
            writer({"type": "task_failed", "task_id": task_id, "error": result.error})
            logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
            cleanup_background_task(task_id)
            return f"Task failed. Error: {result.error}"
        elif result.status == SubagentStatus.TIMED_OUT:
            writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
            logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
            cleanup_background_task(task_id)
            return f"Task timed out. Error: {result.error}"

        # Still running, wait before the next poll
        time.sleep(5)  # Poll every 5 seconds
        poll_count += 1

        # Polling timeout as a safety net (in case the thread pool timeout doesn't fire).
        # Set to the execution timeout + 60s buffer, in 5s poll intervals.
        # This catches edge cases where the background task gets stuck.
        # Note: we don't call cleanup_background_task here because the task may
        # still be running in the background. The cleanup will happen when the
        # executor completes and sets a terminal status.
        if poll_count > max_poll_count:
            timeout_minutes = config.timeout_seconds // 60
            logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by the thread pool timeout)")
            writer({"type": "task_timed_out", "task_id": task_id})
            return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
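The polling budget above is worth spelling out: with a 5s interval and a 60s buffer over the execution timeout, a subagent configured for 10 minutes gets 132 polls before the safety net trips. A quick check of that arithmetic:

# Worked example of the max_poll_count formula from task_tool above.
timeout_seconds = 600          # a subagent hypothetically configured for 10 minutes
poll_interval = 5              # task_tool sleeps 5s between polls
max_poll_count = (timeout_seconds + 60) // poll_interval
assert max_poll_count == 132   # 660s of polling budget at 5s per poll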
@@ -0,0 +1,94 @@
import base64
import mimetypes
from pathlib import Path
from typing import Annotated

from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT

from deerflow.agents.thread_state import ThreadState
from deerflow.sandbox.tools import get_thread_data, replace_virtual_path


@tool("view_image", parse_docstring=True)
def view_image_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    image_path: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
    """Read an image file.

    Use this tool to read an image file and make it available for display.

    When to use the view_image tool:
    - When you need to view an image file.

    When NOT to use the view_image tool:
    - For non-image files (use present_files instead)
    - For multiple files at once (use present_files instead)

    Args:
        image_path: Absolute path to the image file. Supported formats: jpg, jpeg, png, webp.
    """
    # Replace virtual path with actual path
    # /mnt/user-data/* paths are mapped to thread-specific directories
    thread_data = get_thread_data(runtime)
    actual_path = replace_virtual_path(image_path, thread_data)

    # Validate that the path is absolute
    path = Path(actual_path)
    if not path.is_absolute():
        return Command(
            update={"messages": [ToolMessage(f"Error: Path must be absolute, got: {image_path}", tool_call_id=tool_call_id)]},
        )

    # Validate that the file exists
    if not path.exists():
        return Command(
            update={"messages": [ToolMessage(f"Error: Image file not found: {image_path}", tool_call_id=tool_call_id)]},
        )

    # Validate that it's a file (not a directory)
    if not path.is_file():
        return Command(
            update={"messages": [ToolMessage(f"Error: Path is not a file: {image_path}", tool_call_id=tool_call_id)]},
        )

    # Validate image extension
    valid_extensions = {".jpg", ".jpeg", ".png", ".webp"}
    if path.suffix.lower() not in valid_extensions:
        return Command(
            update={"messages": [ToolMessage(f"Error: Unsupported image format: {path.suffix}. Supported formats: {', '.join(valid_extensions)}", tool_call_id=tool_call_id)]},
        )

    # Detect MIME type from the file extension
    mime_type, _ = mimetypes.guess_type(actual_path)
    if mime_type is None:
        # Fall back to default MIME types for common image formats
        extension_to_mime = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".webp": "image/webp",
        }
        mime_type = extension_to_mime.get(path.suffix.lower(), "application/octet-stream")

    # Read the image file and convert it to base64
    try:
        with open(actual_path, "rb") as f:
            image_data = f.read()
        image_base64 = base64.b64encode(image_data).decode("utf-8")
    except Exception as e:
        return Command(
            update={"messages": [ToolMessage(f"Error reading image file: {str(e)}", tool_call_id=tool_call_id)]},
        )

    # Update viewed_images in state
    # The merge_viewed_images reducer will handle merging with existing images
    new_viewed_images = {image_path: {"base64": image_base64, "mime_type": mime_type}}

    return Command(
        update={"viewed_images": new_viewed_images, "messages": [ToolMessage("Successfully read image", tool_call_id=tool_call_id)]},
    )
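A small sketch of what a downstream consumer can do with a viewed_images entry, for example building a data URL for a multimodal message block. The entry shape matches the Command update above; the image bytes are a hypothetical stand-in, not a real image:

# Illustrative: turn a viewed_images entry into a data URL. The payload below
# is a stand-in byte string (a PNG magic header plus padding), not a real image.
import base64

fake_image_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16
entry = {"base64": base64.b64encode(fake_image_bytes).decode("utf-8"), "mime_type": "image/png"}
data_url = f"data:{entry['mime_type']};base64,{entry['base64']}"
assert data_url.startswith("data:image/png;base64,")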
backend/packages/harness/deerflow/tools/tools.py (Normal file, 84 lines added)
@@ -0,0 +1,84 @@
import logging

from langchain.tools import BaseTool

from deerflow.config import get_app_config
from deerflow.reflection import resolve_variable
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool

logger = logging.getLogger(__name__)

BUILTIN_TOOLS = [
    present_file_tool,
    ask_clarification_tool,
]

SUBAGENT_TOOLS = [
    task_tool,
    # task_status_tool is no longer exposed to the LLM (the backend handles polling internally)
]


def get_available_tools(
    groups: list[str] | None = None,
    include_mcp: bool = True,
    model_name: str | None = None,
    subagent_enabled: bool = False,
) -> list[BaseTool]:
    """Get all available tools from config.

    Note: MCP tools should be initialized at application startup using
    `initialize_mcp_tools()` from the deerflow.mcp module.

    Args:
        groups: Optional list of tool groups to filter by.
        include_mcp: Whether to include tools from MCP servers (default: True).
        model_name: Optional model name to determine if vision tools should be included.
        subagent_enabled: Whether to include subagent tools (task).

    Returns:
        List of available tools.
    """
    config = get_app_config()
    loaded_tools = [resolve_variable(tool.use, BaseTool) for tool in config.tools if groups is None or tool.group in groups]

    # Get cached MCP tools if enabled.
    # NOTE: We use ExtensionsConfig.from_file() instead of config.extensions
    # to always read the latest configuration from disk. This ensures that changes
    # made through the Gateway API (which runs in a separate process) are immediately
    # reflected when loading MCP tools.
    mcp_tools = []
    if include_mcp:
        try:
            from deerflow.config.extensions_config import ExtensionsConfig
            from deerflow.mcp.cache import get_cached_mcp_tools

            extensions_config = ExtensionsConfig.from_file()
            if extensions_config.get_enabled_mcp_servers():
                mcp_tools = get_cached_mcp_tools()
                if mcp_tools:
                    logger.info(f"Using {len(mcp_tools)} cached MCP tool(s)")
        except ImportError:
            logger.warning("MCP module not available. Install the 'langchain-mcp-adapters' package to enable MCP tools.")
        except Exception as e:
            logger.error(f"Failed to get cached MCP tools: {e}")

    # Conditionally add tools based on config
    builtin_tools = BUILTIN_TOOLS.copy()

    # Add subagent tools only if enabled via the runtime parameter
    if subagent_enabled:
        builtin_tools.extend(SUBAGENT_TOOLS)
        logger.info("Including subagent tools (task)")

    # If no model_name is specified, use the first model (default)
    if model_name is None and config.models:
        model_name = config.models[0].name

    # Add view_image_tool only if the model supports vision
    model_config = config.get_model_config(model_name) if model_name else None
    if model_config is not None and model_config.supports_vision:
        builtin_tools.append(view_image_tool)
        logger.info(f"Including view_image_tool for model '{model_name}' (supports_vision=True)")

    return loaded_tools + builtin_tools + mcp_tools
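The composition order matters: config-loaded tools come first, then builtins (with subagent and vision tools conditionally appended), then MCP tools last. A minimal sketch of that gating logic reduced to plain strings, with hypothetical tool names standing in for BaseTool instances:

# Illustrative reduction of get_available_tools' gating to plain data.
# MCP tools are omitted here; they would be appended after the result.
def compose(loaded, builtins, subagent, vision_tool, *, subagent_enabled, supports_vision):
    tools = list(builtins)
    if subagent_enabled:
        tools += subagent          # task tool only when the caller opts in
    if supports_vision:
        tools.append(vision_tool)  # view_image only for vision-capable models
    return loaded + tools

out = compose(["web_search"], ["present_files", "ask_clarification"], ["task"],
              "view_image", subagent_enabled=True, supports_vision=False)
assert out == ["web_search", "present_files", "ask_clarification", "task"]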
backend/packages/harness/deerflow/utils/file_conversion.py (Normal file, 47 lines added)
@@ -0,0 +1,47 @@
"""File conversion utilities.

Converts document files (PDF, PPT, Excel, Word) to Markdown using markitdown.
No FastAPI or HTTP dependencies — pure utility functions.
"""

import logging
from pathlib import Path

logger = logging.getLogger(__name__)

# File extensions that should be converted to markdown
CONVERTIBLE_EXTENSIONS = {
    ".pdf",
    ".ppt",
    ".pptx",
    ".xls",
    ".xlsx",
    ".doc",
    ".docx",
}


async def convert_file_to_markdown(file_path: Path) -> Path | None:
    """Convert a file to markdown using markitdown.

    Args:
        file_path: Path to the file to convert.

    Returns:
        Path to the markdown file if conversion was successful, None otherwise.
    """
    try:
        from markitdown import MarkItDown

        md = MarkItDown()
        result = md.convert(str(file_path))

        # Save as a .md file with the same name
        md_path = file_path.with_suffix(".md")
        md_path.write_text(result.text_content, encoding="utf-8")

        logger.info(f"Converted {file_path.name} to markdown: {md_path.name}")
        return md_path
    except Exception as e:
        logger.error(f"Failed to convert {file_path.name} to markdown: {e}")
        return None
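A hedged usage sketch: since convert_file_to_markdown does not filter by extension itself, callers are expected to gate on CONVERTIBLE_EXTENSIONS first. The sample path and wrapper function below are hypothetical, and markitdown must be installed for the real call to succeed:

# Hypothetical caller: route only files with convertible extensions through
# convert_file_to_markdown; everything else passes through untouched.
import asyncio
from pathlib import Path

from deerflow.utils.file_conversion import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown

async def to_markdown_if_supported(path: Path) -> Path:
    if path.suffix.lower() in CONVERTIBLE_EXTENSIONS:
        converted = await convert_file_to_markdown(path)
        return converted if converted is not None else path  # fall back on failure
    return path

# asyncio.run(to_markdown_if_supported(Path("report.pdf")))  # hypothetical file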
backend/packages/harness/deerflow/utils/network.py (Normal file, 139 lines added)
@@ -0,0 +1,139 @@
"""Thread-safe network utilities."""

import socket
import threading
from contextlib import contextmanager


class PortAllocator:
    """Thread-safe port allocator that prevents port conflicts in concurrent environments.

    This class maintains a set of reserved ports and uses a lock to ensure that
    port allocation is atomic. Once a port is allocated, it remains reserved until
    explicitly released.

    Usage:
        allocator = PortAllocator()

        # Option 1: Manual allocation and release
        port = allocator.allocate(start_port=8080)
        try:
            ...  # Use the port
        finally:
            allocator.release(port)

        # Option 2: Context manager (recommended)
        with allocator.allocate_context(start_port=8080) as port:
            ...  # Use the port
        # The port is automatically released when exiting the context
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._reserved_ports: set[int] = set()

    def _is_port_available(self, port: int) -> bool:
        """Check if a port is available for binding.

        Args:
            port: The port number to check.

        Returns:
            True if the port is available, False otherwise.
        """
        if port in self._reserved_ports:
            return False

        # Bind to 0.0.0.0 (wildcard) rather than localhost so that the check
        # mirrors exactly what Docker does. Docker binds to 0.0.0.0:PORT;
        # checking only 127.0.0.1 can falsely report a port as available even
        # when Docker already occupies it on the wildcard address.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("0.0.0.0", port))
                return True
            except OSError:
                return False

    def allocate(self, start_port: int = 8080, max_range: int = 100) -> int:
        """Allocate an available port in a thread-safe manner.

        This method is thread-safe. It finds an available port, marks it as reserved,
        and returns it. The port remains reserved until release() is called.

        Args:
            start_port: The port number to start searching from.
            max_range: Maximum number of ports to search.

        Returns:
            An available port number.

        Raises:
            RuntimeError: If no available port is found in the searched range.
        """
        with self._lock:
            for port in range(start_port, start_port + max_range):
                if self._is_port_available(port):
                    self._reserved_ports.add(port)
                    return port

            raise RuntimeError(f"No available port found in range {start_port}-{start_port + max_range - 1}")

    def release(self, port: int) -> None:
        """Release a previously allocated port.

        Args:
            port: The port number to release.
        """
        with self._lock:
            self._reserved_ports.discard(port)

    @contextmanager
    def allocate_context(self, start_port: int = 8080, max_range: int = 100):
        """Context manager for port allocation with automatic release.

        Args:
            start_port: The port number to start searching from.
            max_range: Maximum number of ports to search.

        Yields:
            An available port number.
        """
        port = self.allocate(start_port, max_range)
        try:
            yield port
        finally:
            self.release(port)


# Global port allocator instance for shared use across the application
_global_port_allocator = PortAllocator()


def get_free_port(start_port: int = 8080, max_range: int = 100) -> int:
    """Get a free port in a thread-safe manner.

    This function uses a global port allocator to ensure that concurrent calls
    don't return the same port. The port is marked as reserved until release_port()
    is called.

    Args:
        start_port: The port number to start searching from.
        max_range: Maximum number of ports to search.

    Returns:
        An available port number.

    Raises:
        RuntimeError: If no available port is found in the searched range.
    """
    return _global_port_allocator.allocate(start_port, max_range)


def release_port(port: int) -> None:
    """Release a previously allocated port.

    Args:
        port: The port number to release.
    """
    _global_port_allocator.release(port)
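A short sketch exercising both allocation styles from this module. Note the asymmetry: the context manager releases automatically, while get_free_port keeps the reservation until release_port is called explicitly:

# Sketch of both allocation styles from deerflow.utils.network.
from deerflow.utils.network import PortAllocator, get_free_port, release_port

allocator = PortAllocator()
with allocator.allocate_context(start_port=8080) as port:
    print(f"serving on {port}")  # reserved only for the body of this block

port = get_free_port(start_port=9000)
try:
    print(f"container bound to {port}")
finally:
    release_port(port)  # skipping this leaks the reservation for the process lifetime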
backend/packages/harness/deerflow/utils/readability.py (Normal file, 83 lines added)
@@ -0,0 +1,83 @@
import logging
import re
import subprocess
from urllib.parse import urljoin

from markdownify import markdownify as md
from readabilipy import simple_json_from_html_string

logger = logging.getLogger(__name__)


class Article:
    # Callers set the source URL after construction; default to "" so that
    # to_message() never raises AttributeError when no URL was provided.
    url: str = ""

    def __init__(self, title: str, html_content: str):
        self.title = title
        self.html_content = html_content

    def to_markdown(self, including_title: bool = True) -> str:
        markdown = ""
        if including_title:
            markdown += f"# {self.title}\n\n"

        if self.html_content is None or not str(self.html_content).strip():
            markdown += "*No content available*\n"
        else:
            markdown += md(self.html_content)

        return markdown

    def to_message(self) -> list[dict]:
        image_pattern = r"!\[.*?\]\((.*?)\)"

        content: list[dict[str, str]] = []
        markdown = self.to_markdown()

        if not markdown or not markdown.strip():
            return [{"type": "text", "text": "No content available"}]

        parts = re.split(image_pattern, markdown)

        for i, part in enumerate(parts):
            if i % 2 == 1:
                # Odd indices hold the image URLs captured by re.split
                image_url = urljoin(self.url, part.strip())
                content.append({"type": "image_url", "image_url": {"url": image_url}})
            else:
                text_part = part.strip()
                if text_part:
                    content.append({"type": "text", "text": text_part})

        # If, after processing all parts, content is still empty, provide a fallback message.
        if not content:
            content = [{"type": "text", "text": "No content available"}]

        return content


class ReadabilityExtractor:
    def extract_article(self, html: str) -> Article:
        try:
            article = simple_json_from_html_string(html, use_readability=True)
        except (subprocess.CalledProcessError, FileNotFoundError) as exc:
            stderr = getattr(exc, "stderr", None)
            if isinstance(stderr, bytes):
                stderr = stderr.decode(errors="replace")
            stderr_info = f"; stderr={stderr.strip()}" if isinstance(stderr, str) and stderr.strip() else ""
            logger.warning(
                "Readability.js extraction failed with %s%s; falling back to pure-Python extraction",
                type(exc).__name__,
                stderr_info,
                exc_info=True,
            )
            article = simple_json_from_html_string(html, use_readability=False)

        html_content = article.get("content")
        if not html_content or not str(html_content).strip():
            html_content = "No content could be extracted from this page"

        title = article.get("title")
        if not title or not str(title).strip():
            title = "Untitled"

        return Article(title=title, html_content=html_content)
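A minimal end-to-end sketch of the extractor on a small inline page. As the except branch above handles, readabilipy falls back to pure-Python parsing when Node's Readability.js is unavailable, so this should run either way (readabilipy and markdownify installed; the URL is hypothetical):

# Sketch: extract an Article from inline HTML and render it for an LLM message.
from deerflow.utils.readability import ReadabilityExtractor

html = "<html><head><title>Hello</title></head><body><p>World</p></body></html>"
article = ReadabilityExtractor().extract_article(html)
article.url = "https://example.com/hello"  # so relative image links resolve in to_message()
print(article.to_markdown())   # "# Hello" heading followed by the extracted body text
print(article.to_message())    # [{"type": "text", "text": ...}]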