refactor: split backend into harness (deerflow.*) and app (app.*) (#1131)

* refactor: extract shared utils to break harness→app cross-layer imports

Move _validate_skill_frontmatter to src/skills/validation.py and
CONVERTIBLE_EXTENSIONS + convert_file_to_markdown to src/utils/file_conversion.py.
This eliminates the two reverse dependencies from client.py (harness layer)
into gateway/routers/ (app layer), preparing for the harness/app package split.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: split backend/src into harness (deerflow.*) and app (app.*)

Physically split the monolithic backend/src/ package into two layers:

- **Harness** (`packages/harness/deerflow/`): publishable agent framework
  package with import prefix `deerflow.*`. Contains agents, sandbox, tools,
  models, MCP, skills, config, and all core infrastructure.

- **App** (`app/`): unpublished application code with import prefix `app.*`.
  Contains gateway (FastAPI REST API) and channels (IM integrations).

Key changes:
- Move 13 harness modules to packages/harness/deerflow/ via git mv
- Move gateway + channels to app/ via git mv
- Rename all imports: src.* → deerflow.* (harness) / app.* (app layer)
- Set up uv workspace with deerflow-harness as workspace member
- Update langgraph.json, config.example.yaml, all scripts, Docker files
- Add build-system (hatchling) to harness pyproject.toml
- Add PYTHONPATH=. to gateway startup commands for app.* resolution
- Update ruff.toml with known-first-party for import sorting
- Update all documentation to reflect new directory structure

Boundary rule enforced: harness code never imports from app.
All 429 tests pass. Lint clean.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: add harness→app boundary check test and update docs

Add test_harness_boundary.py that scans all Python files in
packages/harness/deerflow/ and fails if any `from app.*` or
`import app.*` statement is found. This enforces the architectural
rule that the harness layer never depends on the app layer.

Update CLAUDE.md to document the harness/app split architecture,
import conventions, and the boundary enforcement test.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: add config versioning with auto-upgrade on startup

When config.example.yaml schema changes, developers' local config.yaml
files can silently become outdated. This adds a config_version field and
auto-upgrade mechanism so breaking changes (like src.* → deerflow.*
renames) are applied automatically before services start.

- Add config_version: 1 to config.example.yaml
- Add startup version check warning in AppConfig.from_file()
- Add scripts/config-upgrade.sh with migration registry for value replacements
- Add `make config-upgrade` target
- Auto-run config-upgrade in serve.sh and start-daemon.sh before starting services
- Add config error hints in service failure messages

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix comments

* fix: update src.* import in test_sandbox_tools_security to deerflow.*

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: handle empty config and search parent dirs for config.example.yaml

Address Copilot review comments on PR #1131:
- Guard against yaml.safe_load() returning None for empty config files
- Search parent directories for config.example.yaml instead of only
  looking next to config.yaml, fixing detection in common setups

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: correct skills root path depth and config_version type coercion

- loader.py: fix get_skills_root_path() to use 5 parent levels (was 3)
  after harness split, file lives at packages/harness/deerflow/skills/
  so parent×3 resolved to backend/packages/harness/ instead of backend/
- app_config.py: coerce config_version to int() before comparison in
  _check_config_version() to prevent TypeError when YAML stores value
  as string (e.g. config_version: "1")
- tests: add regression tests for both fixes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: update test imports from src.* to deerflow.*/app.* after harness refactor

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
DanielWalnut
2026-03-14 22:55:52 +08:00
committed by GitHub
parent 9b49a80dda
commit 76803b826f
198 changed files with 1786 additions and 941 deletions

View File

@@ -0,0 +1,5 @@
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
from .lead_agent import make_lead_agent
from .thread_state import SandboxState, ThreadState
__all__ = ["make_lead_agent", "SandboxState", "ThreadState", "get_checkpointer", "reset_checkpointer", "make_checkpointer"]

View File

@@ -0,0 +1,9 @@
from .async_provider import make_checkpointer
from .provider import checkpointer_context, get_checkpointer, reset_checkpointer
__all__ = [
"get_checkpointer",
"reset_checkpointer",
"checkpointer_context",
"make_checkpointer",
]

View File

@@ -0,0 +1,109 @@
"""Async checkpointer factory.
Provides an **async context manager** for long-running async servers that need
proper resource cleanup.
Supported backends: memory, sqlite, postgres.
Usage (e.g. FastAPI lifespan)::
from deerflow.agents.checkpointer.async_provider import make_checkpointer
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer # InMemorySaver if not configured
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
"""
from __future__ import annotations
import contextlib
import logging
from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.agents.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
SQLITE_INSTALL,
_resolve_sqlite_conn_str,
)
from deerflow.config.app_config import get_app_config
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Async factory
# ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs and tears down a checkpointer."""
if config.type == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
if config.type == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
import pathlib
conn_str = _resolve_sqlite_conn_str(config.connection_string or "store.db")
# Only create parent directories for real filesystem paths
if conn_str != ":memory:" and not conn_str.startswith("file:"):
pathlib.Path(conn_str).parent.mkdir(parents=True, exist_ok=True)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if config.type == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED)
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
await saver.setup()
yield saver
return
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
# ---------------------------------------------------------------------------
# Public async context manager
# ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit — no global state::
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
async with _async_checkpointer(config.checkpointer) as saver:
yield saver

View File

@@ -0,0 +1,201 @@
"""Sync checkpointer factory.
Provides a **sync singleton** and a **sync context manager** for LangGraph
graph compilation and CLI tools.
Supported backends: memory, sqlite, postgres.
Usage::
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
# Singleton — reused across calls, closed on process exit
cp = get_checkpointer()
# One-shot — fresh connection, closed on block exit
with checkpointer_context() as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
"""
from __future__ import annotations
import contextlib
import logging
from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.paths import resolve_path
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Error message constants — imported by aio.provider too
# ---------------------------------------------------------------------------
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
# Sync factory
# ---------------------------------------------------------------------------
def _resolve_sqlite_conn_str(raw: str) -> str:
"""Return a SQLite connection string ready for use with ``SqliteSaver``.
SQLite special strings (``":memory:"`` and ``file:`` URIs) are returned
unchanged. Plain filesystem paths — relative or absolute — are resolved
to an absolute string via :func:`resolve_path`.
"""
if raw == ":memory:" or raw.startswith("file:"):
return raw
return str(resolve_path(raw))
@contextlib.contextmanager
def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
"""Context manager that creates and tears down a sync checkpointer.
Returns a configured ``Checkpointer`` instance. Resource cleanup for any
underlying connections or pools is handled by higher-level helpers in
this module (such as the singleton factory or context manager); this
function does not return a separate cleanup callback.
"""
if config.type == "memory":
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
yield InMemorySaver()
return
if config.type == "sqlite":
try:
from langgraph.checkpoint.sqlite import SqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = _resolve_sqlite_conn_str(config.connection_string or "store.db")
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
yield saver
return
if config.type == "postgres":
try:
from langgraph.checkpoint.postgres import PostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED)
with PostgresSaver.from_conn_string(config.connection_string) as saver:
saver.setup()
logger.info("Checkpointer: using PostgresSaver")
yield saver
return
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
# ---------------------------------------------------------------------------
# Sync singleton
# ---------------------------------------------------------------------------
_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
def get_checkpointer() -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Raises:
ImportError: If the required package for the configured backend is not installed.
ValueError: If ``connection_string`` is missing for a backend that requires it.
"""
global _checkpointer, _checkpointer_ctx
if _checkpointer is not None:
return _checkpointer
# Ensure app config is loaded before checking checkpointer config
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
if _app_config is None:
# Only load config if it hasn't been initialized yet
# In tests, config may be set directly via set_checkpointer_config()
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected
# Tests will set config directly via set_checkpointer_config()
pass
config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
return _checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
return _checkpointer
def reset_checkpointer() -> None:
"""Reset the sync singleton, forcing recreation on the next call.
Closes any open backend connections and clears the cached instance.
Useful in tests or after a configuration change.
"""
global _checkpointer, _checkpointer_ctx
if _checkpointer_ctx is not None:
try:
_checkpointer_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during checkpointer cleanup", exc_info=True)
_checkpointer_ctx = None
_checkpointer = None
# ---------------------------------------------------------------------------
# Sync context manager
# ---------------------------------------------------------------------------
@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
each ``with`` block creates and destroys its own connection. Use it in
CLI scripts or tests where you want deterministic cleanup::
with checkpointer_context() as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
with _sync_checkpointer_cm(config.checkpointer) as saver:
yield saver

View File

@@ -0,0 +1,3 @@
from .agent import make_lead_agent
__all__ = ["make_lead_agent"]

View File

@@ -0,0 +1,334 @@
import logging
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config
from deerflow.config.app_config import get_app_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
def _resolve_model_name(requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
if requested_model_name and app_config.get_model_config(requested_model_name):
return requested_model_name
if requested_model_name and requested_model_name != default_model_name:
logger.warning(f"Model '{requested_model_name}' not found in config; fallback to default model '{default_model_name}'.")
return default_model_name
def _create_summarization_middleware() -> SummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
if not config.enabled:
return None
# Prepare trigger parameter
trigger = None
if config.trigger is not None:
if isinstance(config.trigger, list):
trigger = [t.to_tuple() for t in config.trigger]
else:
trigger = config.trigger.to_tuple()
# Prepare keep parameter
keep = config.keep.to_tuple()
# Prepare model parameter
if config.model_name:
model = config.model_name
else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False)
# Prepare kwargs
kwargs = {
"model": model,
"trigger": trigger,
"keep": keep,
}
if config.trim_tokens_to_summarize is not None:
kwargs["trim_tokens_to_summarize"] = config.trim_tokens_to_summarize
if config.summary_prompt is not None:
kwargs["summary_prompt"] = config.summary_prompt
return SummarizationMiddleware(**kwargs)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
"""Create and configure the TodoList middleware.
Args:
is_plan_mode: Whether to enable plan mode with TodoList middleware.
Returns:
TodoMiddleware instance if plan mode is enabled, None otherwise.
"""
if not is_plan_mode:
return None
# Custom prompts matching DeerFlow's style
system_prompt = """
<todo_list_system>
You have access to the `write_todos` tool to help you manage and track complex multi-step objectives.
**CRITICAL RULES:**
- Mark todos as completed IMMEDIATELY after finishing each step - do NOT batch completions
- Keep EXACTLY ONE task as `in_progress` at any time (unless tasks can run in parallel)
- Update the todo list in REAL-TIME as you work - this gives users visibility into your progress
- DO NOT use this tool for simple tasks (< 3 steps) - just complete them directly
**When to Use:**
This tool is designed for complex objectives that require systematic tracking:
- Complex multi-step tasks requiring 3+ distinct steps
- Non-trivial tasks needing careful planning and execution
- User explicitly requests a todo list
- User provides multiple tasks (numbered or comma-separated list)
- The plan may need revisions based on intermediate results
**When NOT to Use:**
- Single, straightforward tasks
- Trivial tasks (< 3 steps)
- Purely conversational or informational requests
- Simple tool calls where the approach is obvious
**Best Practices:**
- Break down complex tasks into smaller, actionable steps
- Use clear, descriptive task names
- Remove tasks that become irrelevant
- Add new tasks discovered during implementation
- Don't be afraid to revise the todo list as you learn more
**Task Management:**
Writing todos takes time and tokens - use it when helpful for managing complex problems, not for simple requests.
</todo_list_system>
"""
tool_description = """Use this tool to create and manage a structured task list for complex work sessions.
**IMPORTANT: Only use this tool for complex tasks (3+ steps). For simple requests, just do the work directly.**
## When to Use
Use this tool in these scenarios:
1. **Complex multi-step tasks**: When a task requires 3 or more distinct steps or actions
2. **Non-trivial tasks**: Tasks requiring careful planning or multiple operations
3. **User explicitly requests todo list**: When the user directly asks you to track tasks
4. **Multiple tasks**: When users provide a list of things to be done
5. **Dynamic planning**: When the plan may need updates based on intermediate results
## When NOT to Use
Skip this tool when:
1. The task is straightforward and takes less than 3 steps
2. The task is trivial and tracking provides no benefit
3. The task is purely conversational or informational
4. It's clear what needs to be done and you can just do it
## How to Use
1. **Starting a task**: Mark it as `in_progress` BEFORE beginning work
2. **Completing a task**: Mark it as `completed` IMMEDIATELY after finishing
3. **Updating the list**: Add new tasks, remove irrelevant ones, or update descriptions as needed
4. **Multiple updates**: You can make several updates at once (e.g., complete one task and start the next)
## Task States
- `pending`: Task not yet started
- `in_progress`: Currently working on (can have multiple if tasks run in parallel)
- `completed`: Task finished successfully
## Task Completion Requirements
**CRITICAL: Only mark a task as completed when you have FULLY accomplished it.**
Never mark a task as completed if:
- There are unresolved issues or errors
- Work is partial or incomplete
- You encountered blockers preventing completion
- You couldn't find necessary resources or dependencies
- Quality standards haven't been met
If blocked, keep the task as `in_progress` and create a new task describing what needs to be resolved.
## Best Practices
- Create specific, actionable items
- Break complex tasks into smaller, manageable steps
- Use clear, descriptive task names
- Update task status in real-time as you work
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
- Remove tasks that are no longer relevant
- **IMPORTANT**: When you write the todo list, mark your first task(s) as `in_progress` immediately
- **IMPORTANT**: Unless all tasks are completed, always have at least one task `in_progress` to show progress
Being proactive with task management demonstrates thoroughness and ensures all requirements are completed successfully.
**Remember**: If you only need a few tool calls to complete a task and it's clear what to do, it's better to just do the task directly and NOT use this tool at all.
"""
return TodoMiddleware(system_prompt=system_prompt, tool_description=tool_description)
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
# UploadsMiddleware should be after ThreadDataMiddleware to access thread_id
# DanglingToolCallMiddleware patches missing ToolMessages before model sees the history
# SummarizationMiddleware should be early to reduce context before other processing
# TodoListMiddleware should be before ClarificationMiddleware to allow todo management
# TitleMiddleware generates title after first exchange
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None):
"""Build middleware chain based on runtime configuration.
Args:
config: Runtime configuration containing configurable options like is_plan_mode.
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
Returns:
List of middleware instances.
"""
middlewares = build_lead_runtime_middlewares(lazy_init=True)
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware()
if summarization_middleware is not None:
middlewares.append(summarization_middleware)
# Add TodoList middleware if plan mode is enabled
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
if todo_list_middleware is not None:
middlewares.append(todo_list_middleware)
# Add TitleMiddleware
middlewares.append(TitleMiddleware())
# Add MemoryMiddleware (after TitleMiddleware)
middlewares.append(MemoryMiddleware(agent_name=agent_name))
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
# Add SubagentLimitMiddleware to truncate excess parallel task calls
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
if subagent_enabled:
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
middlewares.append(LoopDetectionMiddleware())
# ClarificationMiddleware should always be last
middlewares.append(ClarificationMiddleware())
return middlewares
def make_lead_agent(config: RunnableConfig):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
cfg = config.get("configurable", {})
thinking_enabled = cfg.get("thinking_enabled", True)
reasoning_effort = cfg.get("reasoning_effort", None)
requested_model_name: str | None = cfg.get("model_name") or cfg.get("model")
is_plan_mode = cfg.get("is_plan_mode", False)
subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
is_bootstrap = cfg.get("is_bootstrap", False)
agent_name = cfg.get("agent_name")
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
# Custom agent model or fallback to global/default model resolution
agent_model_name = agent_config.model if agent_config and agent_config.model else _resolve_model_name()
# Final model name resolution with request override, then agent config, then global default
model_name = requested_model_name or agent_model_name
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is None:
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
if thinking_enabled and not model_config.supports_thinking:
logger.warning(f"Thinking mode is enabled but model '{model_name}' does not support it; fallback to non-thinking mode.")
thinking_enabled = False
logger.info(
"Create Agent(%s) -> thinking_enabled: %s, reasoning_effort: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s",
agent_name or "default",
thinking_enabled,
reasoning_effort,
model_name,
is_plan_mode,
subagent_enabled,
max_concurrent_subagents,
)
# Inject run metadata for LangSmith trace tagging
if "metadata" not in config:
config["metadata"] = {}
config["metadata"].update(
{
"agent_name": agent_name or "default",
"model_name": model_name or "default",
"thinking_enabled": thinking_enabled,
"reasoning_effort": reasoning_effort,
"is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled,
}
)
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
system_prompt = apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"]))
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=system_prompt,
state_schema=ThreadState,
)
# Default lead agent (unchanged behavior)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name),
state_schema=ThreadState,
)

View File

@@ -0,0 +1,409 @@
from datetime import datetime
from deerflow.config.agents_config import load_agent_soul
from deerflow.skills import load_skills
def _build_subagent_section(max_concurrent: int) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
Args:
max_concurrent: Maximum number of concurrent subagent calls allowed per response.
Returns:
Formatted subagent section string.
"""
n = max_concurrent
return f"""<subagent_system>
**🚀 SUBAGENT MODE ACTIVE - DECOMPOSE, DELEGATE, SYNTHESIZE**
You are running with subagent capabilities enabled. Your role is to be a **task orchestrator**:
1. **DECOMPOSE**: Break complex tasks into parallel sub-tasks
2. **DELEGATE**: Launch multiple subagents simultaneously using parallel `task` calls
3. **SYNTHESIZE**: Collect and integrate results into a coherent answer
**CORE PRINCIPLE: Complex tasks should be decomposed and distributed across multiple subagents for parallel execution.**
**⛔ HARD CONCURRENCY LIMIT: MAXIMUM {n} `task` CALLS PER RESPONSE. THIS IS NOT OPTIONAL.**
- Each response, you may include **at most {n}** `task` tool calls. Any excess calls are **silently discarded** by the system — you will lose that work.
- **Before launching subagents, you MUST count your sub-tasks in your thinking:**
- If count ≤ {n}: Launch all in this response.
- If count > {n}: **Pick the {n} most important/foundational sub-tasks for this turn.** Save the rest for the next turn.
- **Multi-batch execution** (for >{n} sub-tasks):
- Turn 1: Launch sub-tasks 1-{n} in parallel → wait for results
- Turn 2: Launch next batch in parallel → wait for results
- ... continue until all sub-tasks are complete
- Final turn: Synthesize ALL results into a coherent answer
- **Example thinking pattern**: "I identified 6 sub-tasks. Since the limit is {n} per turn, I will launch the first {n} now, and the rest in the next turn."
**Available Subagents:**
- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.
- **bash**: For command execution (git, build, test, deploy operations)
**Your Orchestration Strategy:**
✅ **DECOMPOSE + PARALLEL EXECUTION (Preferred Approach):**
For complex queries, break them down into focused sub-tasks and execute in parallel batches (max {n} per turn):
**Example 1: "Why is Tencent's stock price declining?" (3 sub-tasks → 1 batch)**
→ Turn 1: Launch 3 subagents in parallel:
- Subagent 1: Recent financial reports, earnings data, and revenue trends
- Subagent 2: Negative news, controversies, and regulatory issues
- Subagent 3: Industry trends, competitor performance, and market sentiment
→ Turn 2: Synthesize results
**Example 2: "Compare 5 cloud providers" (5 sub-tasks → multi-batch)**
→ Turn 1: Launch {n} subagents in parallel (first batch)
→ Turn 2: Launch remaining subagents in parallel
→ Final turn: Synthesize ALL results into comprehensive comparison
**Example 3: "Refactor the authentication system"**
→ Turn 1: Launch 3 subagents in parallel:
- Subagent 1: Analyze current auth implementation and technical debt
- Subagent 2: Research best practices and security patterns
- Subagent 3: Review related tests, documentation, and vulnerabilities
→ Turn 2: Synthesize results
✅ **USE Parallel Subagents (max {n} per turn) when:**
- **Complex research questions**: Requires multiple information sources or perspectives
- **Multi-aspect analysis**: Task has several independent dimensions to explore
- **Large codebases**: Need to analyze different parts simultaneously
- **Comprehensive investigations**: Questions requiring thorough coverage from multiple angles
❌ **DO NOT use subagents (execute directly) when:**
- **Task cannot be decomposed**: If you can't break it into 2+ meaningful parallel sub-tasks, execute directly
- **Ultra-simple actions**: Read one file, quick edits, single commands
- **Need immediate clarification**: Must ask user before proceeding
- **Meta conversation**: Questions about conversation history
- **Sequential dependencies**: Each step depends on previous results (do steps yourself sequentially)
**CRITICAL WORKFLOW** (STRICTLY follow this before EVERY action):
1. **COUNT**: In your thinking, list all sub-tasks and count them explicitly: "I have N sub-tasks"
2. **PLAN BATCHES**: If N > {n}, explicitly plan which sub-tasks go in which batch:
- "Batch 1 (this turn): first {n} sub-tasks"
- "Batch 2 (next turn): next batch of sub-tasks"
3. **EXECUTE**: Launch ONLY the current batch (max {n} `task` calls). Do NOT launch sub-tasks from future batches.
4. **REPEAT**: After results return, launch the next batch. Continue until all batches complete.
5. **SYNTHESIZE**: After ALL batches are done, synthesize all results.
6. **Cannot decompose** → Execute directly using available tools (bash, read_file, web_search, etc.)
**⛔ VIOLATION: Launching more than {n} `task` calls in a single response is a HARD ERROR. The system WILL discard excess calls and you WILL lose work. Always batch.**
**Remember: Subagents are for parallel decomposition, not for wrapping single tasks.**
**How It Works:**
- The task tool runs subagents asynchronously in the background
- The backend automatically polls for completion (you don't need to poll)
- The tool call will block until the subagent completes its work
- Once complete, the result is returned to you directly
**Usage Example 1 - Single Batch (≤{n} sub-tasks):**
```python
# User asks: "Why is Tencent's stock price declining?"
# Thinking: 3 sub-tasks → fits in 1 batch
# Turn 1: Launch 3 subagents in parallel
task(description="Tencent financial data", prompt="...", subagent_type="general-purpose")
task(description="Tencent news & regulation", prompt="...", subagent_type="general-purpose")
task(description="Industry & market trends", prompt="...", subagent_type="general-purpose")
# All 3 run in parallel → synthesize results
```
**Usage Example 2 - Multiple Batches (>{n} sub-tasks):**
```python
# User asks: "Compare AWS, Azure, GCP, Alibaba Cloud, and Oracle Cloud"
# Thinking: 5 sub-tasks → need multiple batches (max {n} per batch)
# Turn 1: Launch first batch of {n}
task(description="AWS analysis", prompt="...", subagent_type="general-purpose")
task(description="Azure analysis", prompt="...", subagent_type="general-purpose")
task(description="GCP analysis", prompt="...", subagent_type="general-purpose")
# Turn 2: Launch remaining batch (after first batch completes)
task(description="Alibaba Cloud analysis", prompt="...", subagent_type="general-purpose")
task(description="Oracle Cloud analysis", prompt="...", subagent_type="general-purpose")
# Turn 3: Synthesize ALL results from both batches
```
**Counter-Example - Direct Execution (NO subagents):**
```python
# User asks: "Run the tests"
# Thinking: Cannot decompose into parallel sub-tasks
# → Execute directly
bash("npm test") # Direct execution, not task()
```
**CRITICAL**:
- **Max {n} `task` calls per turn** - the system enforces this, excess calls are discarded
- Only use `task` when you can launch 2+ subagents in parallel
- Single task = No value from subagents = Execute directly
- For >{n} sub-tasks, use sequential batches of {n} across multiple turns
</subagent_system>"""
SYSTEM_PROMPT_TEMPLATE = """
<role>
You are {agent_name}, an open-source super agent.
</role>
{soul}
{memory_context}
<thinking_style>
- Think concisely and strategically about the user's request BEFORE taking action
- Break down the task: What is clear? What is ambiguous? What is missing?
- **PRIORITY CHECK: If anything is unclear, missing, or has multiple interpretations, you MUST ask for clarification FIRST - do NOT proceed with work**
{subagent_thinking}- Never write down your full final answer or report in thinking process, but only outline
- CRITICAL: After thinking, you MUST provide your actual response to the user. Thinking is for planning, the response is for delivery.
- Your response must contain the actual answer, not just a reference to what you thought about
</thinking_style>
<clarification_system>
**WORKFLOW PRIORITY: CLARIFY → PLAN → ACT**
1. **FIRST**: Analyze the request in your thinking - identify what's unclear, missing, or ambiguous
2. **SECOND**: If clarification is needed, call `ask_clarification` tool IMMEDIATELY - do NOT start working
3. **THIRD**: Only after all clarifications are resolved, proceed with planning and execution
**CRITICAL RULE: Clarification ALWAYS comes BEFORE action. Never start working and clarify mid-execution.**
**MANDATORY Clarification Scenarios - You MUST call ask_clarification BEFORE starting work when:**
1. **Missing Information** (`missing_info`): Required details not provided
- Example: User says "create a web scraper" but doesn't specify the target website
- Example: "Deploy the app" without specifying environment
- **REQUIRED ACTION**: Call ask_clarification to get the missing information
2. **Ambiguous Requirements** (`ambiguous_requirement`): Multiple valid interpretations exist
- Example: "Optimize the code" could mean performance, readability, or memory usage
- Example: "Make it better" is unclear what aspect to improve
- **REQUIRED ACTION**: Call ask_clarification to clarify the exact requirement
3. **Approach Choices** (`approach_choice`): Several valid approaches exist
- Example: "Add authentication" could use JWT, OAuth, session-based, or API keys
- Example: "Store data" could use database, files, cache, etc.
- **REQUIRED ACTION**: Call ask_clarification to let user choose the approach
4. **Risky Operations** (`risk_confirmation`): Destructive actions need confirmation
- Example: Deleting files, modifying production configs, database operations
- Example: Overwriting existing code or data
- **REQUIRED ACTION**: Call ask_clarification to get explicit confirmation
5. **Suggestions** (`suggestion`): You have a recommendation but want approval
- Example: "I recommend refactoring this code. Should I proceed?"
- **REQUIRED ACTION**: Call ask_clarification to get approval
**STRICT ENFORCEMENT:**
- ❌ DO NOT start working and then ask for clarification mid-execution - clarify FIRST
- ❌ DO NOT skip clarification for "efficiency" - accuracy matters more than speed
- ❌ DO NOT make assumptions when information is missing - ALWAYS ask
- ❌ DO NOT proceed with guesses - STOP and call ask_clarification first
- ✅ Analyze the request in thinking → Identify unclear aspects → Ask BEFORE any action
- ✅ If you identify the need for clarification in your thinking, you MUST call the tool IMMEDIATELY
- ✅ After calling ask_clarification, execution will be interrupted automatically
- ✅ Wait for user response - do NOT continue with assumptions
**How to Use:**
```python
ask_clarification(
question="Your specific question here?",
clarification_type="missing_info", # or other type
context="Why you need this information", # optional but recommended
options=["option1", "option2"] # optional, for choices
)
```
**Example:**
User: "Deploy the application"
You (thinking): Missing environment info - I MUST ask for clarification
You (action): ask_clarification(
question="Which environment should I deploy to?",
clarification_type="approach_choice",
context="I need to know the target environment for proper configuration",
options=["development", "staging", "production"]
)
[Execution stops - wait for user response]
User: "staging"
You: "Deploying to staging..." [proceed]
</clarification_system>
{skills_section}
{subagent_section}
<working_directory existed="true">
- User uploads: `/mnt/user-data/uploads` - Files uploaded by the user (automatically listed in context)
- User workspace: `/mnt/user-data/workspace` - Working directory for temporary files
- Output files: `/mnt/user-data/outputs` - Final deliverables must be saved here
**File Management:**
- Uploaded files are automatically listed in the <uploaded_files> section before each request
- Use `read_file` tool to read uploaded files using their paths from the list
- For PDF, PPT, Excel, and Word files, converted Markdown versions (*.md) are available alongside originals
- All temporary work happens in `/mnt/user-data/workspace`
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool
</working_directory>
<response_style>
- Clear and Concise: Avoid over-formatting unless requested
- Natural Tone: Use paragraphs and prose, not bullet points by default
- Action-Oriented: Focus on delivering results, not explaining processes
</response_style>
<citations>
- When to Use: After web_search, include citations if applicable
- Format: Use Markdown link format `[citation:TITLE](URL)`
- Example:
```markdown
The key AI trends for 2026 include enhanced reasoning capabilities and multimodal integration
[citation:AI Trends 2026](https://techcrunch.com/ai-trends).
Recent breakthroughs in language models have also accelerated progress
[citation:OpenAI Research](https://openai.com/research).
```
</citations>
<critical_reminders>
- **Clarification First**: ALWAYS clarify unclear/missing/ambiguous requirements BEFORE starting work - never assume or guess
{subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks.
- Progressive Loading: Load resources incrementally as referenced in skills
- Output Files: Final deliverables must be in `/mnt/user-data/outputs`
- Clarity: Be direct and helpful, avoid unnecessary meta-commentary
- Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `![Image Description](image_path)\n\n` or "```mermaid" to display images in response or Markdown files
- Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance
- Language Consistency: Keep using the same language as user's
- Always Respond: Your thinking is internal. You MUST always provide a visible response to the user after thinking.
</critical_reminders>
"""
def _get_memory_context(agent_name: str | None = None) -> str:
"""Get memory context for injection into system prompt.
Args:
agent_name: If provided, loads per-agent memory. If None, loads global memory.
Returns:
Formatted memory context string wrapped in XML tags, or empty string if disabled.
"""
try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
config = get_memory_config()
if not config.enabled or not config.injection_enabled:
return ""
memory_data = get_memory_data(agent_name)
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
if not memory_content.strip():
return ""
return f"""<memory>
{memory_content}
</memory>
"""
except Exception as e:
print(f"Failed to load memory context: {e}")
return ""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
"""Generate the skills prompt section with available skills list.
Returns the <skill_system>...</skill_system> block listing all enabled skills,
suitable for injection into any agent's system prompt.
"""
skills = load_skills(enabled_only=True)
try:
from deerflow.config import get_app_config
config = get_app_config()
container_base_path = config.skills.container_path
except Exception:
container_base_path = "/mnt/skills"
if not skills:
return ""
if available_skills is not None:
skills = [skill for skill in skills if skill.name in available_skills]
skill_items = "\n".join(
f" <skill>\n <name>{skill.name}</name>\n <description>{skill.description}</description>\n <location>{skill.get_container_file_path(container_base_path)}</location>\n </skill>" for skill in skills
)
skills_list = f"<available_skills>\n{skill_items}\n</available_skills>"
return f"""<skill_system>
You have access to skills that provide optimized workflows for specific tasks. Each skill contains best practices, frameworks, and references to additional resources.
**Progressive Loading Pattern:**
1. When a user query matches a skill's use case, immediately call `read_file` on the skill's main file using the path attribute provided in the skill tag below
2. Read and understand the skill's workflow and instructions
3. The skill file contains references to external resources under the same folder
4. Load referenced resources only when needed during execution
5. Follow the skill's instructions precisely
**Skills are located at:** {container_base_path}
{skills_list}
</skill_system>"""
def get_agent_soul(agent_name: str | None) -> str:
# Append SOUL.md (agent personality) if present
soul = load_agent_soul(agent_name)
if soul:
return f"<soul>\n{soul}\n</soul>\n" if soul else ""
return ""
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name)
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
# Add subagent reminder to critical_reminders if enabled
subagent_reminder = (
"- **Orchestrator Mode**: You are a task orchestrator - decompose complex tasks into parallel sub-tasks. "
f"**HARD LIMIT: max {n} `task` calls per response.** "
f"If >{n} sub-tasks, split into sequential batches of ≤{n}. Synthesize after ALL batches complete.\n"
if subagent_enabled
else ""
)
# Add subagent thinking guidance if enabled
subagent_thinking = (
"- **DECOMPOSITION CHECK: Can this task be broken into 2+ parallel sub-tasks? If YES, COUNT them. "
f"If count > {n}, you MUST plan batches of ≤{n} and only launch the FIRST batch now. "
f"NEVER launch more than {n} `task` calls in one response.**\n"
if subagent_enabled
else ""
)
# Get skills section
skills_section = get_skills_prompt_section(available_skills)
# Format the prompt with dynamic skills and memory
prompt = SYSTEM_PROMPT_TEMPLATE.format(
agent_name=agent_name or "DeerFlow 2.0",
soul=get_agent_soul(agent_name),
skills_section=skills_section,
memory_context=memory_context,
subagent_section=subagent_section,
subagent_reminder=subagent_reminder,
subagent_thinking=subagent_thinking,
)
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"

View File

@@ -0,0 +1,44 @@
"""Memory module for DeerFlow.
This module provides a global memory mechanism that:
- Stores user context and conversation history in memory.json
- Uses LLM to summarize and extract facts from conversations
- Injects relevant memory into system prompts for personalized responses
"""
from deerflow.agents.memory.prompt import (
FACT_EXTRACTION_PROMPT,
MEMORY_UPDATE_PROMPT,
format_conversation_for_update,
format_memory_for_injection,
)
from deerflow.agents.memory.queue import (
ConversationContext,
MemoryUpdateQueue,
get_memory_queue,
reset_memory_queue,
)
from deerflow.agents.memory.updater import (
MemoryUpdater,
get_memory_data,
reload_memory_data,
update_memory_from_conversation,
)
__all__ = [
# Prompt utilities
"MEMORY_UPDATE_PROMPT",
"FACT_EXTRACTION_PROMPT",
"format_memory_for_injection",
"format_conversation_for_update",
# Queue
"ConversationContext",
"MemoryUpdateQueue",
"get_memory_queue",
"reset_memory_queue",
# Updater
"MemoryUpdater",
"get_memory_data",
"reload_memory_data",
"update_memory_from_conversation",
]

View File

@@ -0,0 +1,339 @@
"""Prompt templates for memory update and injection."""
import math
import re
from typing import Any
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
# Prompt template for updating memory based on conversation
MEMORY_UPDATE_PROMPT = """You are a memory management system. Your task is to analyze a conversation and update the user's memory profile.
Current Memory State:
<current_memory>
{current_memory}
</current_memory>
New Conversation to Process:
<conversation>
{conversation}
</conversation>
Instructions:
1. Analyze the conversation for important information about the user
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
3. Update the memory sections as needed following the detailed length guidelines below
Memory Section Guidelines:
**User Context** (Current state - concise summaries):
- workContext: Professional role, company, key projects, main technologies (2-3 sentences)
Example: Core contributor, project names with metrics (16k+ stars), technical stack
- personalContext: Languages, communication preferences, key interests (1-2 sentences)
Example: Bilingual capabilities, specific interest areas, expertise domains
- topOfMind: Multiple ongoing focus areas and priorities (3-5 sentences, detailed paragraph)
Example: Primary project work, parallel technical investigations, ongoing learning/tracking
Include: Active implementation work, troubleshooting issues, market/research interests
Note: This captures SEVERAL concurrent focus areas, not just one task
**History** (Temporal context - rich paragraphs):
- recentMonths: Detailed summary of recent activities (4-6 sentences or 1-2 paragraphs)
Timeline: Last 1-3 months of interactions
Include: Technologies explored, projects worked on, problems solved, interests demonstrated
- earlierContext: Important historical patterns (3-5 sentences or 1 paragraph)
Timeline: 3-12 months ago
Include: Past projects, learning journeys, established patterns
- longTermBackground: Persistent background and foundational context (2-4 sentences)
Timeline: Overall/foundational information
Include: Core expertise, longstanding interests, fundamental working style
**Facts Extraction**:
- Extract specific, quantifiable details (e.g., "16k+ GitHub stars", "200+ datasets")
- Include proper nouns (company names, project names, technology names)
- Preserve technical terminology and version numbers
- Categories:
* preference: Tools, styles, approaches user prefers/dislikes
* knowledge: Specific expertise, technologies mastered, domain knowledge
* context: Background facts (job title, projects, locations, languages)
* behavior: Working patterns, communication habits, problem-solving approaches
* goal: Stated objectives, learning targets, project ambitions
- Confidence levels:
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
* 0.7-0.8: Strongly implied from actions/discussions
* 0.5-0.6: Inferred patterns (use sparingly, only for clear patterns)
**What Goes Where**:
- workContext: Current job, active projects, primary tech stack
- personalContext: Languages, personality, interests outside direct work tasks
- topOfMind: Multiple ongoing priorities and focus areas user cares about recently (gets updated most frequently)
Should capture 3-5 concurrent themes: main work, side explorations, learning/tracking interests
- recentMonths: Detailed account of recent technical explorations and work
- earlierContext: Patterns from slightly older interactions still relevant
- longTermBackground: Unchanging foundational facts about the user
**Multilingual Content**:
- Preserve original language for proper nouns and company names
- Keep technical terms in their original form (DeepSeek, LangGraph, etc.)
- Note language capabilities in personalContext
Output Format (JSON):
{{
"user": {{
"workContext": {{ "summary": "...", "shouldUpdate": true/false }},
"personalContext": {{ "summary": "...", "shouldUpdate": true/false }},
"topOfMind": {{ "summary": "...", "shouldUpdate": true/false }}
}},
"history": {{
"recentMonths": {{ "summary": "...", "shouldUpdate": true/false }},
"earlierContext": {{ "summary": "...", "shouldUpdate": true/false }},
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
}},
"newFacts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
],
"factsToRemove": ["fact_id_1", "fact_id_2"]
}}
Important Rules:
- Only set shouldUpdate=true if there's meaningful new information
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
- Include specific metrics, version numbers, and proper nouns in facts
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
- Remove facts that are contradicted by new information
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
Keep 3-5 concurrent focus themes that are still active and relevant
- For history sections, integrate new information chronologically into appropriate time period
- Preserve technical accuracy - keep exact names of technologies, companies, projects
- Focus on information useful for future interactions and personalization
- IMPORTANT: Do NOT record file upload events in memory. Uploaded files are
session-specific and ephemeral — they will not be accessible in future sessions.
Recording upload events causes confusion in subsequent conversations.
Return ONLY valid JSON, no explanation or markdown."""
# Prompt template for extracting facts from a single message
FACT_EXTRACTION_PROMPT = """Extract factual information about the user from this message.
Message:
{message}
Extract facts in this JSON format:
{{
"facts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
]
}}
Categories:
- preference: User preferences (likes/dislikes, styles, tools)
- knowledge: User's expertise or knowledge areas
- context: Background context (location, job, projects)
- behavior: Behavioral patterns
- goal: User's goals or objectives
Rules:
- Only extract clear, specific facts
- Confidence should reflect certainty (explicit statement = 0.9+, implied = 0.6-0.8)
- Skip vague or temporary information
Return ONLY valid JSON."""
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
"""Count tokens in text using tiktoken.
Args:
text: The text to count tokens for.
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
Returns:
The number of tokens in the text.
"""
if not TIKTOKEN_AVAILABLE:
# Fallback to character-based estimation if tiktoken is not available
return len(text) // 4
try:
encoding = tiktoken.get_encoding(encoding_name)
return len(encoding.encode(text))
except Exception:
# Fallback to character-based estimation on error
return len(text) // 4
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
"""Coerce a confidence-like value to a bounded float in [0, 1].
Non-finite values (NaN, inf, -inf) are treated as invalid and fall back
to the default before clamping, preventing them from dominating ranking.
The ``default`` parameter is assumed to be a finite value.
"""
try:
confidence = float(value)
except (TypeError, ValueError):
return max(0.0, min(1.0, default))
if not math.isfinite(confidence):
return max(0.0, min(1.0, default))
return max(0.0, min(1.0, confidence))
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
"""Format memory data for injection into system prompt.
Args:
memory_data: The memory data dictionary.
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
Returns:
Formatted memory string for system prompt injection.
"""
if not memory_data:
return ""
sections = []
# Format user context
user_data = memory_data.get("user", {})
if user_data:
user_sections = []
work_ctx = user_data.get("workContext", {})
if work_ctx.get("summary"):
user_sections.append(f"Work: {work_ctx['summary']}")
personal_ctx = user_data.get("personalContext", {})
if personal_ctx.get("summary"):
user_sections.append(f"Personal: {personal_ctx['summary']}")
top_of_mind = user_data.get("topOfMind", {})
if top_of_mind.get("summary"):
user_sections.append(f"Current Focus: {top_of_mind['summary']}")
if user_sections:
sections.append("User Context:\n" + "\n".join(f"- {s}" for s in user_sections))
# Format history
history_data = memory_data.get("history", {})
if history_data:
history_sections = []
recent = history_data.get("recentMonths", {})
if recent.get("summary"):
history_sections.append(f"Recent: {recent['summary']}")
earlier = history_data.get("earlierContext", {})
if earlier.get("summary"):
history_sections.append(f"Earlier: {earlier['summary']}")
if history_sections:
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
# Format facts (sorted by confidence; include as many as token budget allows)
facts_data = memory_data.get("facts", [])
if isinstance(facts_data, list) and facts_data:
ranked_facts = sorted(
(
f
for f in facts_data
if isinstance(f, dict)
and isinstance(f.get("content"), str)
and f.get("content").strip()
),
key=lambda fact: _coerce_confidence(fact.get("confidence"), default=0.0),
reverse=True,
)
# Compute token count for existing sections once, then account
# incrementally for each fact line to avoid full-string re-tokenization.
base_text = "\n\n".join(sections)
base_tokens = _count_tokens(base_text) if base_text else 0
# Account for the separator between existing sections and the facts section.
facts_header = "Facts:\n"
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
running_tokens = base_tokens + separator_tokens
fact_lines: list[str] = []
for fact in ranked_facts:
content_value = fact.get("content")
if not isinstance(content_value, str):
continue
content = content_value.strip()
if not content:
continue
category = str(fact.get("category", "context")).strip() or "context"
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
line = f"- [{category} | {confidence:.2f}] {content}"
# Each additional line is preceded by a newline (except the first).
line_text = ("\n" + line) if fact_lines else line
line_tokens = _count_tokens(line_text)
if running_tokens + line_tokens <= max_tokens:
fact_lines.append(line)
running_tokens += line_tokens
else:
break
if fact_lines:
sections.append("Facts:\n" + "\n".join(fact_lines))
if not sections:
return ""
result = "\n\n".join(sections)
# Use accurate token counting with tiktoken
token_count = _count_tokens(result)
if token_count > max_tokens:
# Truncate to fit within token limit
# Estimate characters to remove based on token ratio
char_per_token = len(result) / token_count
target_chars = int(max_tokens * char_per_token * 0.95) # 95% to leave margin
result = result[:target_chars] + "\n..."
return result
def format_conversation_for_update(messages: list[Any]) -> str:
"""Format conversation messages for memory update prompt.
Args:
messages: List of conversation messages.
Returns:
Formatted conversation string.
"""
lines = []
for msg in messages:
role = getattr(msg, "type", "unknown")
content = getattr(msg, "content", str(msg))
# Handle content that might be a list (multimodal)
if isinstance(content, list):
text_parts = [p.get("text", "") for p in content if isinstance(p, dict) and "text" in p]
content = " ".join(text_parts) if text_parts else str(content)
# Strip uploaded_files tags from human messages to avoid persisting
# ephemeral file path info into long-term memory. Skip the turn entirely
# when nothing remains after stripping (upload-only message).
if role == "human":
content = re.sub(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", "", str(content)).strip()
if not content:
continue
# Truncate very long messages
if len(str(content)) > 1000:
content = str(content)[:1000] + "..."
if role == "human":
lines.append(f"User: {content}")
elif role == "ai":
lines.append(f"Assistant: {content}")
return "\n\n".join(lines)

View File

@@ -0,0 +1,195 @@
"""Memory update queue with debounce mechanism."""
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from deerflow.config.memory_config import get_memory_config
@dataclass
class ConversationContext:
"""Context for a conversation to be processed for memory update."""
thread_id: str
messages: list[Any]
timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None
class MemoryUpdateQueue:
"""Queue for memory updates with debounce mechanism.
This queue collects conversation contexts and processes them after
a configurable debounce period. Multiple conversations received within
the debounce window are batched together.
"""
def __init__(self):
"""Initialize the memory update queue."""
self._queue: list[ConversationContext] = []
self._lock = threading.Lock()
self._timer: threading.Timer | None = None
self._processing = False
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
"""Add a conversation to the update queue.
Args:
thread_id: The thread ID.
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
"""
config = get_memory_config()
if not config.enabled:
return
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
)
with self._lock:
# Check if this thread already has a pending update
# If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
# Reset or start the debounce timer
self._reset_timer()
print(f"Memory update queued for thread {thread_id}, queue size: {len(self._queue)}")
def _reset_timer(self) -> None:
"""Reset the debounce timer."""
config = get_memory_config()
# Cancel existing timer if any
if self._timer is not None:
self._timer.cancel()
# Start new timer
self._timer = threading.Timer(
config.debounce_seconds,
self._process_queue,
)
self._timer.daemon = True
self._timer.start()
print(f"Memory update timer set for {config.debounce_seconds}s")
def _process_queue(self) -> None:
"""Process all queued conversation contexts."""
# Import here to avoid circular dependency
from deerflow.agents.memory.updater import MemoryUpdater
with self._lock:
if self._processing:
# Already processing, reschedule
self._reset_timer()
return
if not self._queue:
return
self._processing = True
contexts_to_process = self._queue.copy()
self._queue.clear()
self._timer = None
print(f"Processing {len(contexts_to_process)} queued memory updates")
try:
updater = MemoryUpdater()
for context in contexts_to_process:
try:
print(f"Updating memory for thread {context.thread_id}")
success = updater.update_memory(
messages=context.messages,
thread_id=context.thread_id,
agent_name=context.agent_name,
)
if success:
print(f"Memory updated successfully for thread {context.thread_id}")
else:
print(f"Memory update skipped/failed for thread {context.thread_id}")
except Exception as e:
print(f"Error updating memory for thread {context.thread_id}: {e}")
# Small delay between updates to avoid rate limiting
if len(contexts_to_process) > 1:
time.sleep(0.5)
finally:
with self._lock:
self._processing = False
def flush(self) -> None:
"""Force immediate processing of the queue.
This is useful for testing or graceful shutdown.
"""
with self._lock:
if self._timer is not None:
self._timer.cancel()
self._timer = None
self._process_queue()
def clear(self) -> None:
"""Clear the queue without processing.
This is useful for testing.
"""
with self._lock:
if self._timer is not None:
self._timer.cancel()
self._timer = None
self._queue.clear()
self._processing = False
@property
def pending_count(self) -> int:
"""Get the number of pending updates."""
with self._lock:
return len(self._queue)
@property
def is_processing(self) -> bool:
"""Check if the queue is currently being processed."""
with self._lock:
return self._processing
# Global singleton instance
_memory_queue: MemoryUpdateQueue | None = None
_queue_lock = threading.Lock()
def get_memory_queue() -> MemoryUpdateQueue:
"""Get the global memory update queue singleton.
Returns:
The memory update queue instance.
"""
global _memory_queue
with _queue_lock:
if _memory_queue is None:
_memory_queue = MemoryUpdateQueue()
return _memory_queue
def reset_memory_queue() -> None:
"""Reset the global memory queue.
This is useful for testing.
"""
global _memory_queue
with _queue_lock:
if _memory_queue is not None:
_memory_queue.clear()
_memory_queue = None

View File

@@ -0,0 +1,384 @@
"""Memory updater for reading, writing, and updating memory data."""
import json
import re
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
from deerflow.agents.memory.prompt import (
MEMORY_UPDATE_PROMPT,
format_conversation_for_update,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model
def _get_memory_file_path(agent_name: str | None = None) -> Path:
"""Get the path to the memory file.
Args:
agent_name: If provided, returns the per-agent memory file path.
If None, returns the global memory file path.
Returns:
Path to the memory file.
"""
if agent_name is not None:
return get_paths().agent_memory_file(agent_name)
config = get_memory_config()
if config.storage_path:
p = Path(config.storage_path)
# Absolute path: use as-is; relative path: resolve against base_dir
return p if p.is_absolute() else get_paths().base_dir / p
return get_paths().memory_file
def _create_empty_memory() -> dict[str, Any]:
"""Create an empty memory structure."""
return {
"version": "1.0",
"lastUpdated": datetime.utcnow().isoformat() + "Z",
"user": {
"workContext": {"summary": "", "updatedAt": ""},
"personalContext": {"summary": "", "updatedAt": ""},
"topOfMind": {"summary": "", "updatedAt": ""},
},
"history": {
"recentMonths": {"summary": "", "updatedAt": ""},
"earlierContext": {"summary": "", "updatedAt": ""},
"longTermBackground": {"summary": "", "updatedAt": ""},
},
"facts": [],
}
# Per-agent memory cache: keyed by agent_name (None = global)
# Value: (memory_data, file_mtime)
_memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
"""Get the current memory data (cached with file modification time check).
The cache is automatically invalidated if the memory file has been modified
since the last load, ensuring fresh data is always returned.
Args:
agent_name: If provided, loads per-agent memory. If None, loads global memory.
Returns:
The memory data dictionary.
"""
file_path = _get_memory_file_path(agent_name)
# Get current file modification time
try:
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError:
current_mtime = None
cached = _memory_cache.get(agent_name)
# Invalidate cache if file has been modified or doesn't exist
if cached is None or cached[1] != current_mtime:
memory_data = _load_memory_from_file(agent_name)
_memory_cache[agent_name] = (memory_data, current_mtime)
return memory_data
return cached[0]
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation.
Args:
agent_name: If provided, reloads per-agent memory. If None, reloads global memory.
Returns:
The reloaded memory data dictionary.
"""
file_path = _get_memory_file_path(agent_name)
memory_data = _load_memory_from_file(agent_name)
try:
mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError:
mtime = None
_memory_cache[agent_name] = (memory_data, mtime)
return memory_data
def _load_memory_from_file(agent_name: str | None = None) -> dict[str, Any]:
"""Load memory data from file.
Args:
agent_name: If provided, loads per-agent memory file. If None, loads global.
Returns:
The memory data dictionary.
"""
file_path = _get_memory_file_path(agent_name)
if not file_path.exists():
return _create_empty_memory()
try:
with open(file_path, encoding="utf-8") as f:
data = json.load(f)
return data
except (json.JSONDecodeError, OSError) as e:
print(f"Failed to load memory file: {e}")
return _create_empty_memory()
# Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export".
_UPLOAD_SENTENCE_RE = re.compile(
r"[^.!?]*\b(?:"
r"upload(?:ed|ing)?(?:\s+\w+){0,3}\s+(?:file|files?|document|documents?|attachment|attachments?)"
r"|file\s+upload"
r"|/mnt/user-data/uploads/"
r"|<uploaded_files>"
r")[^.!?]*[.!?]?\s*",
re.IGNORECASE,
)
def _strip_upload_mentions_from_memory(memory_data: dict[str, Any]) -> dict[str, Any]:
"""Remove sentences about file uploads from all memory summaries and facts.
Uploaded files are session-scoped; persisting upload events in long-term
memory causes the agent to search for non-existent files in future sessions.
"""
# Scrub summaries in user/history sections
for section in ("user", "history"):
section_data = memory_data.get(section, {})
for _key, val in section_data.items():
if isinstance(val, dict) and "summary" in val:
cleaned = _UPLOAD_SENTENCE_RE.sub("", val["summary"]).strip()
cleaned = re.sub(r" +", " ", cleaned)
val["summary"] = cleaned
# Also remove any facts that describe upload events
facts = memory_data.get("facts", [])
if facts:
memory_data["facts"] = [f for f in facts if not _UPLOAD_SENTENCE_RE.search(f.get("content", ""))]
return memory_data
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Save memory data to file and update cache.
Args:
memory_data: The memory data to save.
agent_name: If provided, saves to per-agent memory file. If None, saves to global.
Returns:
True if successful, False otherwise.
"""
file_path = _get_memory_file_path(agent_name)
try:
# Ensure directory exists
file_path.parent.mkdir(parents=True, exist_ok=True)
# Update lastUpdated timestamp
memory_data["lastUpdated"] = datetime.utcnow().isoformat() + "Z"
# Write atomically using temp file
temp_path = file_path.with_suffix(".tmp")
with open(temp_path, "w", encoding="utf-8") as f:
json.dump(memory_data, f, indent=2, ensure_ascii=False)
# Rename temp file to actual file (atomic on most systems)
temp_path.replace(file_path)
# Update cache and file modification time
try:
mtime = file_path.stat().st_mtime
except OSError:
mtime = None
_memory_cache[agent_name] = (memory_data, mtime)
print(f"Memory saved to {file_path}")
return True
except OSError as e:
print(f"Failed to save memory file: {e}")
return False
class MemoryUpdater:
"""Updates memory using LLM based on conversation context."""
def __init__(self, model_name: str | None = None):
"""Initialize the memory updater.
Args:
model_name: Optional model name to use. If None, uses config or default.
"""
self._model_name = model_name
def _get_model(self):
"""Get the model for memory updates."""
config = get_memory_config()
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
"""Update memory based on conversation messages.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
Returns:
True if update was successful, False otherwise.
"""
config = get_memory_config()
if not config.enabled:
return False
if not messages:
return False
try:
# Get current memory
current_memory = get_memory_data(agent_name)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = str(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return _save_memory_to_file(updated_memory, agent_name)
except json.JSONDecodeError as e:
print(f"Failed to parse LLM response for memory update: {e}")
return False
except Exception as e:
print(f"Memory update failed: {e}")
return False
def _apply_updates(
self,
current_memory: dict[str, Any],
update_data: dict[str, Any],
thread_id: str | None = None,
) -> dict[str, Any]:
"""Apply LLM-generated updates to memory.
Args:
current_memory: Current memory data.
update_data: Updates from LLM.
thread_id: Optional thread ID for tracking.
Returns:
Updated memory data.
"""
config = get_memory_config()
now = datetime.utcnow().isoformat() + "Z"
# Update user sections
user_updates = update_data.get("user", {})
for section in ["workContext", "personalContext", "topOfMind"]:
section_data = user_updates.get(section, {})
if section_data.get("shouldUpdate") and section_data.get("summary"):
current_memory["user"][section] = {
"summary": section_data["summary"],
"updatedAt": now,
}
# Update history sections
history_updates = update_data.get("history", {})
for section in ["recentMonths", "earlierContext", "longTermBackground"]:
section_data = history_updates.get(section, {})
if section_data.get("shouldUpdate") and section_data.get("summary"):
current_memory["history"][section] = {
"summary": section_data["summary"],
"updatedAt": now,
}
# Remove facts
facts_to_remove = set(update_data.get("factsToRemove", []))
if facts_to_remove:
current_memory["facts"] = [f for f in current_memory.get("facts", []) if f.get("id") not in facts_to_remove]
# Add new facts
new_facts = update_data.get("newFacts", [])
for fact in new_facts:
confidence = fact.get("confidence", 0.5)
if confidence >= config.fact_confidence_threshold:
fact_entry = {
"id": f"fact_{uuid.uuid4().hex[:8]}",
"content": fact.get("content", ""),
"category": fact.get("category", "context"),
"confidence": confidence,
"createdAt": now,
"source": thread_id or "unknown",
}
current_memory["facts"].append(fact_entry)
# Enforce max facts limit
if len(current_memory["facts"]) > config.max_facts:
# Sort by confidence and keep top ones
current_memory["facts"] = sorted(
current_memory["facts"],
key=lambda f: f.get("confidence", 0),
reverse=True,
)[: config.max_facts]
return current_memory
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
"""Convenience function to update memory from a conversation.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name)

View File

@@ -0,0 +1,173 @@
"""Middleware for intercepting clarification requests and presenting them to the user."""
from collections.abc import Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.graph import END
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
class ClarificationMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
pass
class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
"""Intercepts clarification tool calls and interrupts execution to present questions to the user.
When the model calls the `ask_clarification` tool, this middleware:
1. Intercepts the tool call before execution
2. Extracts the clarification question and metadata
3. Formats a user-friendly message
4. Returns a Command that interrupts execution and presents the question
5. Waits for user response before continuing
This replaces the tool-based approach where clarification continued the conversation flow.
"""
state_schema = ClarificationMiddlewareState
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters.
Args:
text: Text to check
Returns:
True if text contains Chinese characters
"""
return any("\u4e00" <= char <= "\u9fff" for char in text)
def _format_clarification_message(self, args: dict) -> str:
"""Format the clarification arguments into a user-friendly message.
Args:
args: The tool call arguments containing clarification details
Returns:
Formatted message string
"""
question = args.get("question", "")
clarification_type = args.get("clarification_type", "missing_info")
context = args.get("context")
options = args.get("options", [])
# Type-specific icons
type_icons = {
"missing_info": "",
"ambiguous_requirement": "🤔",
"approach_choice": "🔀",
"risk_confirmation": "⚠️",
"suggestion": "💡",
}
icon = type_icons.get(clarification_type, "")
# Build the message naturally
message_parts = []
# Add icon and question together for a more natural flow
if context:
# If there's context, present it first as background
message_parts.append(f"{icon} {context}")
message_parts.append(f"\n{question}")
else:
# Just the question with icon
message_parts.append(f"{icon} {question}")
# Add options in a cleaner format
if options and len(options) > 0:
message_parts.append("") # blank line for spacing
for i, option in enumerate(options, 1):
message_parts.append(f" {i}. {option}")
return "\n".join(message_parts)
def _handle_clarification(self, request: ToolCallRequest) -> Command:
"""Handle clarification request and return command to interrupt execution.
Args:
request: Tool call request
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Extract clarification arguments
args = request.tool_call.get("args", {})
question = args.get("question", "")
print("[ClarificationMiddleware] Intercepted clarification request")
print(f"[ClarificationMiddleware] Question: {question}")
# Format the clarification message
formatted_message = self._format_clarification_message(args)
# Get the tool call ID
tool_call_id = request.tool_call.get("id", "")
# Create a ToolMessage with the formatted question
# This will be added to the message history
tool_message = ToolMessage(
content=formatted_message,
tool_call_id=tool_call_id,
name="ask_clarification",
)
# Return a Command that:
# 1. Adds the formatted tool message
# 2. Interrupts execution by going to __end__
# Note: We don't add an extra AIMessage here - the frontend will detect
# and display ask_clarification tool messages directly
return Command(
update={"messages": [tool_message]},
goto=END,
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept ask_clarification tool calls and interrupt execution (sync version).
Args:
request: Tool call request
handler: Original tool execution handler
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Check if this is an ask_clarification tool call
if request.tool_call.get("name") != "ask_clarification":
# Not a clarification call, execute normally
return handler(request)
return self._handle_clarification(request)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept ask_clarification tool calls and interrupt execution (async version).
Args:
request: Tool call request
handler: Original tool execution handler (async)
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Check if this is an ask_clarification tool call
if request.tool_call.get("name") != "ask_clarification":
# Not a clarification call, execute normally
return await handler(request)
return self._handle_clarification(request)

View File

@@ -0,0 +1,110 @@
"""Middleware to fix dangling tool calls in message history.
A dangling tool call occurs when an AIMessage contains tool_calls but there are
no corresponding ToolMessages in the history (e.g., due to user interruption or
request cancellation). This causes LLM errors due to incomplete message format.
This middleware intercepts the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator immediately after the
AIMessage that made the tool calls, ensuring correct message ordering.
Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
at the correct positions (immediately after each dangling AIMessage), not appended
to the end of the message list as before_model + add_messages reducer would do.
"""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
logger = logging.getLogger(__name__)
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
Scans the message history for AIMessages whose tool_calls lack corresponding
ToolMessages, and injects synthetic error responses immediately after the
offending AIMessage so the LLM receives a well-formed conversation.
"""
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
a synthetic ToolMessage is inserted immediately after that AIMessage.
Returns None if no patches are needed.
"""
# Collect IDs of all existing ToolMessages
existing_tool_msg_ids: set[str] = set()
for msg in messages:
if isinstance(msg, ToolMessage):
existing_tool_msg_ids.add(msg.tool_call_id)
# Check if any patching is needed
needs_patch = False
for msg in messages:
if getattr(msg, "type", None) != "ai":
continue
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True
break
if needs_patch:
break
if not needs_patch:
return None
# Build new list with patches inserted right after each dangling AIMessage
patched: list = []
patched_ids: set[str] = set()
patch_count = 0
for msg in messages:
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
ToolMessage(
content="[Tool call was interrupted and did not return a result.]",
tool_call_id=tc_id,
name=tc.get("name", "unknown"),
status="error",
)
)
patched_ids.add(tc_id)
patch_count += 1
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
return patched
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return handler(request)
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return await handler(request)

View File

@@ -0,0 +1,227 @@
"""Middleware to detect and break repetitive tool call loops.
P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.
Detection strategy:
1. After each model response, hash the tool calls (name + args).
2. Track recent hashes in a sliding window.
3. If the same hash appears >= warn_threshold times, inject a
"you are repeating yourself — wrap up" system message (once per hash).
4. If it appears >= hard_limit times, strip all tool_calls from the
response so the agent is forced to produce a final text answer.
"""
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import SystemMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
def _hash_tool_calls(tool_calls: list[dict]) -> str:
"""Deterministic hash of a set of tool calls (name + args).
This is intended to be order-independent: the same multiset of tool calls
should always produce the same hash, regardless of their input order.
"""
# First normalize each tool call to a minimal (name, args) structure.
normalized: list[dict] = []
for tc in tool_calls:
normalized.append(
{
"name": tc.get("name", ""),
"args": tc.get("args", {}),
}
)
# Sort by both name and a deterministic serialization of args so that
# permutations of the same multiset of calls yield the same ordering.
normalized.sort(
key=lambda tc: (
tc["name"],
json.dumps(tc["args"], sort_keys=True, default=str),
)
)
blob = json.dumps(normalized, sort_keys=True, default=str)
return hashlib.md5(blob.encode()).hexdigest()[:12]
_WARNING_MSG = (
"[LOOP DETECTED] You are repeating the same tool calls. "
"Stop calling tools and produce your final answer now. "
"If you cannot complete the task, summarize what you accomplished so far."
)
_HARD_STOP_MSG = (
"[FORCED STOP] Repeated tool calls exceeded the safety limit. "
"Producing final answer with results collected so far."
)
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
Args:
warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3.
hard_limit: Number of identical tool call sets before stripping
tool_calls entirely. Default: 5.
window_size: Size of the sliding window for tracking calls.
Default: 20.
max_tracked_threads: Maximum number of threads to track before
evicting the least recently used. Default: 100.
"""
def __init__(
self,
warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
hard_limit: int = _DEFAULT_HARD_LIMIT,
window_size: int = _DEFAULT_WINDOW_SIZE,
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
):
super().__init__()
self.warn_threshold = warn_threshold
self.hard_limit = hard_limit
self.window_size = window_size
self.max_tracked_threads = max_tracked_threads
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id")
if thread_id:
return thread_id
return "default"
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
Must be called while holding self._lock.
"""
while len(self._history) > self.max_tracked_threads:
evicted_id, _ = self._history.popitem(last=False)
self._warned.pop(evicted_id, None)
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
"""Track tool calls and check for loops.
Returns:
(warning_message_or_none, should_hard_stop)
"""
messages = state.get("messages", [])
if not messages:
return None, False
last_msg = messages[-1]
if getattr(last_msg, "type", None) != "ai":
return None, False
tool_calls = getattr(last_msg, "tool_calls", None)
if not tool_calls:
return None, False
thread_id = self._get_thread_id(runtime)
call_hash = _hash_tool_calls(tool_calls)
with self._lock:
# Touch / create entry (move to end for LRU)
if thread_id in self._history:
self._history.move_to_end(thread_id)
else:
self._history[thread_id] = []
self._evict_if_needed()
history = self._history[thread_id]
history.append(call_hash)
if len(history) > self.window_size:
history[:] = history[-self.window_size:]
count = history.count(call_hash)
tool_names = [tc.get("name", "?") for tc in tool_calls]
if count >= self.hard_limit:
logger.error(
"Loop hard limit reached — forcing stop",
extra={
"thread_id": thread_id,
"call_hash": call_hash,
"count": count,
"tools": tool_names,
},
)
return _HARD_STOP_MSG, True
if count >= self.warn_threshold:
warned = self._warned[thread_id]
if call_hash not in warned:
warned.add(call_hash)
logger.warning(
"Repetitive tool calls detected — injecting warning",
extra={
"thread_id": thread_id,
"call_hash": call_hash,
"count": count,
"tools": tool_names,
},
)
return _WARNING_MSG, False
# Warning already injected for this hash — suppress
return None, False
return None, False
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
if hard_stop:
# Strip tool_calls from the last AIMessage to force text output
messages = state.get("messages", [])
last_msg = messages[-1]
stripped_msg = last_msg.model_copy(update={
"tool_calls": [],
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
})
return {"messages": [stripped_msg]}
if warning:
# Inject a system message warning the model
return {"messages": [SystemMessage(content=warning)]}
return None
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None:
"""Clear tracking state. If thread_id given, clear only that thread."""
with self._lock:
if thread_id:
self._history.pop(thread_id, None)
self._warned.pop(thread_id, None)
else:
self._history.clear()
self._warned.clear()

View File

@@ -0,0 +1,149 @@
"""Middleware for memory mechanism."""
import re
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
pass
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
This filters out:
- Tool messages (intermediate tool call results)
- AI messages with tool_calls (intermediate steps, not final responses)
- The <uploaded_files> block injected by UploadsMiddleware into human messages
(file paths are session-scoped and must not persist in long-term memory).
The user's actual question is preserved; only turns whose content is entirely
the upload block (nothing remains after stripping) are dropped along with
their paired assistant response.
Only keeps:
- Human messages (with the ephemeral upload block removed)
- AI messages without tool_calls (final assistant responses), unless the
paired human turn was upload-only and had no real user text.
Args:
messages: List of all conversation messages.
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content = getattr(msg, "content", "")
if isinstance(content, list):
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
content_str = str(content)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
# Nothing left — the entire turn was upload bookkeeping;
# skip it and the paired assistant response.
skip_next_ai = True
continue
# Rebuild the message with cleaned content so the user's question
# is still available for memory summarisation.
from copy import copy
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
# Skip tool messages and AI messages with tool_calls
return filtered
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
This middleware:
1. After each agent execution, queues the conversation for memory update
2. Only includes user inputs and final assistant responses (ignores tool calls)
3. The queue uses debouncing to batch multiple updates together
4. Memory is updated asynchronously via LLM summarization
"""
state_schema = MemoryMiddlewareState
def __init__(self, agent_name: str | None = None):
"""Initialize the MemoryMiddleware.
Args:
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
"""
super().__init__()
self._agent_name = agent_name
@override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
"""Queue conversation for memory update after agent completes.
Args:
state: The current agent state.
runtime: The runtime context.
Returns:
None (no state changes needed from this middleware).
"""
config = get_memory_config()
if not config.enabled:
return None
# Get thread ID from runtime context
thread_id = runtime.context.get("thread_id")
if not thread_id:
print("MemoryMiddleware: No thread_id in context, skipping memory update")
return None
# Get messages from state
messages = state.get("messages", [])
if not messages:
print("MemoryMiddleware: No messages in state, skipping memory update")
return None
# Filter to only keep user inputs and final assistant responses
filtered_messages = _filter_messages_for_memory(messages)
# Only queue if there's meaningful conversation
# At minimum need one user message and one assistant response
user_messages = [m for m in filtered_messages if getattr(m, "type", None) == "human"]
assistant_messages = [m for m in filtered_messages if getattr(m, "type", None) == "ai"]
if not user_messages or not assistant_messages:
return None
# Queue the filtered conversation for memory update
queue = get_memory_queue()
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
return None

View File

@@ -0,0 +1,75 @@
"""Middleware to enforce maximum concurrent subagent tool calls per model response."""
import logging
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
logger = logging.getLogger(__name__)
# Valid range for max_concurrent_subagents
MIN_SUBAGENT_LIMIT = 2
MAX_SUBAGENT_LIMIT = 4
def _clamp_subagent_limit(value: int) -> int:
"""Clamp subagent limit to valid range [2, 4]."""
return max(MIN_SUBAGENT_LIMIT, min(MAX_SUBAGENT_LIMIT, value))
class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
"""Truncates excess 'task' tool calls from a single model response.
When an LLM generates more than max_concurrent parallel task tool calls
in one response, this middleware keeps only the first max_concurrent and
discards the rest. This is more reliable than prompt-based limits.
Args:
max_concurrent: Maximum number of concurrent subagent calls allowed.
Defaults to MAX_CONCURRENT_SUBAGENTS (3). Clamped to [2, 4].
"""
def __init__(self, max_concurrent: int = MAX_CONCURRENT_SUBAGENTS):
super().__init__()
self.max_concurrent = _clamp_subagent_limit(max_concurrent)
def _truncate_task_calls(self, state: AgentState) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last_msg = messages[-1]
if getattr(last_msg, "type", None) != "ai":
return None
tool_calls = getattr(last_msg, "tool_calls", None)
if not tool_calls:
return None
# Count task tool calls
task_indices = [i for i, tc in enumerate(tool_calls) if tc.get("name") == "task"]
if len(task_indices) <= self.max_concurrent:
return None
# Build set of indices to drop (excess task calls beyond the limit)
indices_to_drop = set(task_indices[self.max_concurrent :])
truncated_tool_calls = [tc for i, tc in enumerate(tool_calls) if i not in indices_to_drop]
dropped_count = len(indices_to_drop)
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
return {"messages": [updated_msg]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._truncate_task_calls(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._truncate_task_calls(state)

View File

@@ -0,0 +1,90 @@
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import Paths, get_paths
class ThreadDataMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
thread_data: NotRequired[ThreadDataState | None]
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
"""Create thread data directories for each thread execution.
Creates the following directory structure:
- {base_dir}/threads/{thread_id}/user-data/workspace
- {base_dir}/threads/{thread_id}/user-data/uploads
- {base_dir}/threads/{thread_id}/user-data/outputs
Lifecycle Management:
- With lazy_init=True (default): Only compute paths, directories created on-demand
- With lazy_init=False: Eagerly create directories in before_agent()
"""
state_schema = ThreadDataMiddlewareState
def __init__(self, base_dir: str | None = None, lazy_init: bool = True):
"""Initialize the middleware.
Args:
base_dir: Base directory for thread data. Defaults to Paths resolution.
lazy_init: If True, defer directory creation until needed.
If False, create directories eagerly in before_agent().
Default is True for optimal performance.
"""
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
self._lazy_init = lazy_init
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
"""Get the paths for a thread's data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with workspace_path, uploads_path, and outputs_path.
"""
return {
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
}
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
"""Create the thread data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with the created directory paths.
"""
self._paths.ensure_thread_dirs(thread_id)
return self._get_thread_paths(thread_id)
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
thread_id = runtime.context.get("thread_id")
if thread_id is None:
raise ValueError("Thread ID is required in the context")
if self._lazy_init:
# Lazy initialization: only compute paths, don't create directories
paths = self._get_thread_paths(thread_id)
else:
# Eager initialization: create directories immediately
paths = self._create_thread_directories(thread_id)
print(f"Created thread data directories for thread {thread_id}")
return {
"thread_data": {
**paths,
}
}

View File

@@ -0,0 +1,93 @@
"""Middleware for automatic thread title generation."""
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model
class TitleMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
title: NotRequired[str | None]
class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
"""Automatically generate a title for the thread after the first user message."""
state_schema = TitleMiddlewareState
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = get_title_config()
if not config.enabled:
return False
# Check if thread already has a title in state
if state.get("title"):
return False
# Check if this is the first turn (has at least one user message and one assistant response)
messages = state.get("messages", [])
if len(messages) < 2:
return False
# Count user and assistant messages
user_messages = [m for m in messages if m.type == "human"]
assistant_messages = [m for m in messages if m.type == "ai"]
# Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1
async def _generate_title(self, state: TitleMiddlewareState) -> str:
"""Generate a concise title based on the conversation."""
config = get_title_config()
messages = state.get("messages", [])
# Get first user message and first assistant response
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
# Ensure content is string (LangChain messages can have list content)
user_msg = str(user_msg_content) if user_msg_content else ""
assistant_msg = str(assistant_msg_content) if assistant_msg_content else ""
# Use a lightweight model to generate title
model = create_chat_model(thinking_enabled=False)
prompt = config.prompt_template.format(
max_words=config.max_words,
user_msg=user_msg[:500],
assistant_msg=assistant_msg[:500],
)
try:
response = await model.ainvoke(prompt)
# Ensure response content is string
title_content = str(response.content) if response.content else ""
title = title_content.strip().strip('"').strip("'")
# Limit to max characters
return title[: config.max_chars] if len(title) > config.max_chars else title
except Exception as e:
print(f"Failed to generate title: {e}")
# Fallback: use first part of user message (by character count)
fallback_chars = min(config.max_chars, 50) # Use max_chars or 50, whichever is smaller
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
@override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
"""Generate and set thread title after the first agent response."""
if self._should_generate_title(state):
title = await self._generate_title(state)
print(f"Generated thread title: {title}")
# Store title in state (will be persisted by checkpointer if configured)
return {"title": title}
return None

View File

@@ -0,0 +1,100 @@
"""Middleware that extends TodoListMiddleware with context-loss detection.
When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list.
"""
from __future__ import annotations
from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
def _todos_in_messages(messages: list[Any]) -> bool:
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
for msg in messages:
if isinstance(msg, AIMessage) and msg.tool_calls:
for tc in msg.tool_calls:
if tc.get("name") == "write_todos":
return True
return False
def _reminder_in_messages(messages: list[Any]) -> bool:
"""Return True if a todo_reminder HumanMessage is already present in *messages*."""
for msg in messages:
if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_reminder":
return True
return False
def _format_todos(todos: list[Todo]) -> str:
"""Format a list of Todo items into a human-readable string."""
lines: list[str] = []
for todo in todos:
status = todo.get("status", "pending")
content = todo.get("content", "")
lines.append(f"- [{status}] {content}")
return "\n".join(lines)
class TodoMiddleware(TodoListMiddleware):
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
When the original `write_todos` tool call has been truncated from the message
history (e.g., after summarization), the model loses awareness of the current
todo list. This middleware detects that gap in `before_model` / `abefore_model`
and injects a reminder message so the model can continue tracking progress.
"""
@override
def before_model(
self,
state: PlanningState,
runtime: Runtime, # noqa: ARG002
) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window."""
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
if not todos:
return None
messages = state.get("messages") or []
if _todos_in_messages(messages):
# write_todos is still visible in context — nothing to do.
return None
if _reminder_in_messages(messages):
# A reminder was already injected and hasn't been truncated yet.
return None
# The todo list exists in state but the original write_todos call is gone.
# Inject a reminder as a HumanMessage so the model stays aware.
formatted = _format_todos(todos)
reminder = HumanMessage(
name="todo_reminder",
content=(
"<system_reminder>\n"
"Your todo list from earlier is no longer visible in the current context window, "
"but it is still active. Here is the current state:\n\n"
f"{formatted}\n\n"
"Continue tracking and updating this todo list as you work. "
"Call `write_todos` whenever the status of any item changes.\n"
"</system_reminder>"
),
)
return {"messages": [reminder]}
@override
async def abefore_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of before_model."""
return self.before_model(state, runtime)

View File

@@ -0,0 +1,112 @@
"""Tool error handling middleware and shared runtime middleware builders."""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Convert tool exceptions into error ToolMessages so the run can continue."""
def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
tool_name = str(request.tool_call.get("name") or "unknown_tool")
tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
detail = str(exc).strip() or exc.__class__.__name__
if len(detail) > 500:
detail = detail[:497] + "..."
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
def _build_runtime_middlewares(
*,
include_uploads: bool,
include_dangling_tool_call_patch: bool,
lazy_init: bool = True,
) -> list[AgentMiddleware]:
"""Build shared base middlewares for agent execution."""
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.sandbox.middleware import SandboxMiddleware
middlewares: list[AgentMiddleware] = [
ThreadDataMiddleware(lazy_init=lazy_init),
SandboxMiddleware(lazy_init=lazy_init),
]
if include_uploads:
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
middlewares.insert(1, UploadsMiddleware())
if include_dangling_tool_call_patch:
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(ToolErrorHandlingMiddleware())
return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares(
include_uploads=True,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
include_uploads=False,
include_dangling_tool_call_patch=False,
lazy_init=lazy_init,
)

View File

@@ -0,0 +1,204 @@
"""Middleware to inject uploaded files information into agent context."""
import logging
from pathlib import Path
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
logger = logging.getLogger(__name__)
class UploadsMiddlewareState(AgentState):
"""State schema for uploads middleware."""
uploaded_files: NotRequired[list[dict] | None]
class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
"""Middleware to inject uploaded files information into the agent context.
Reads file metadata from the current message's additional_kwargs.files
(set by the frontend after upload) and prepends an <uploaded_files> block
to the last human message so the model knows which files are available.
"""
state_schema = UploadsMiddlewareState
def __init__(self, base_dir: str | None = None):
"""Initialize the middleware.
Args:
base_dir: Base directory for thread data. Defaults to Paths resolution.
"""
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
"""Create a formatted message listing uploaded files.
Args:
new_files: Files uploaded in the current message.
historical_files: Files uploaded in previous messages.
Returns:
Formatted string inside <uploaded_files> tags.
"""
lines = ["<uploaded_files>"]
lines.append("The following files were uploaded in this message:")
lines.append("")
if new_files:
for file in new_files:
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
lines.append("")
else:
lines.append("(empty)")
if historical_files:
lines.append("The following files were uploaded in previous messages and are still available:")
lines.append("")
for file in historical_files:
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
lines.append("")
lines.append("You can read these files using the `read_file` tool with the paths shown above.")
lines.append("</uploaded_files>")
return "\n".join(lines)
def _files_from_kwargs(self, message: HumanMessage, uploads_dir: Path | None = None) -> list[dict] | None:
"""Extract file info from message additional_kwargs.files.
The frontend sends uploaded file metadata in additional_kwargs.files
after a successful upload. Each entry has: filename, size (bytes),
path (virtual path), status.
Args:
message: The human message to inspect.
uploads_dir: Physical uploads directory used to verify file existence.
When provided, entries whose files no longer exist are skipped.
Returns:
List of file dicts with virtual paths, or None if the field is absent or empty.
"""
kwargs_files = (message.additional_kwargs or {}).get("files")
if not isinstance(kwargs_files, list) or not kwargs_files:
return None
files = []
for f in kwargs_files:
if not isinstance(f, dict):
continue
filename = f.get("filename") or ""
if not filename or Path(filename).name != filename:
continue
if uploads_dir is not None and not (uploads_dir / filename).is_file():
continue
files.append(
{
"filename": filename,
"size": int(f.get("size") or 0),
"path": f"/mnt/user-data/uploads/{filename}",
"extension": Path(filename).suffix,
}
)
return files if files else None
@override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files.
Historical files are scanned from the thread's uploads directory,
excluding the new ones.
Prepends <uploaded_files> context to the last human message content.
The original additional_kwargs (including files metadata) is preserved
on the updated message so the frontend can read it from the stream.
Args:
state: Current agent state.
runtime: Runtime context containing thread_id.
Returns:
State updates including uploaded files list.
"""
messages = list(state.get("messages", []))
if not messages:
return None
last_message_index = len(messages) - 1
last_message = messages[last_message_index]
if not isinstance(last_message, HumanMessage):
return None
# Resolve uploads directory for existence checks
thread_id = runtime.context.get("thread_id")
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
# Collect historical files from the uploads directory (all except the new ones)
new_filenames = {f["filename"] for f in new_files}
historical_files: list[dict] = []
if uploads_dir and uploads_dir.exists():
for file_path in sorted(uploads_dir.iterdir()):
if file_path.is_file() and file_path.name not in new_filenames:
stat = file_path.stat()
historical_files.append(
{
"filename": file_path.name,
"size": stat.st_size,
"path": f"/mnt/user-data/uploads/{file_path.name}",
"extension": file_path.suffix,
}
)
if not new_files and not historical_files:
return None
logger.debug(f"New files: {[f['filename'] for f in new_files]}, historical: {[f['filename'] for f in historical_files]}")
# Create files message and prepend to the last human message content
files_message = self._create_files_message(new_files, historical_files)
# Extract original content - handle both string and list formats
original_content = ""
if isinstance(last_message.content, str):
original_content = last_message.content
elif isinstance(last_message.content, list):
text_parts = []
for block in last_message.content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
original_content = "\n".join(text_parts)
# Create new message with combined content.
# Preserve additional_kwargs (including files metadata) so the frontend
# can read structured file info from the streamed message.
updated_message = HumanMessage(
content=f"{files_message}\n\n{original_content}",
id=last_message.id,
additional_kwargs=last_message.additional_kwargs,
)
messages[last_message_index] = updated_message
return {
"uploaded_files": new_files,
"messages": messages,
}

View File

@@ -0,0 +1,221 @@
"""Middleware for injecting image details into conversation before LLM call."""
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ViewedImageData
class ViewImageMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
viewed_images: NotRequired[dict[str, ViewedImageData] | None]
class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
"""Injects image details as a human message before LLM calls when view_image tools have completed.
This middleware:
1. Runs before each LLM call
2. Checks if the last assistant message contains view_image tool calls
3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages)
4. If conditions are met, creates a human message with all viewed image details (including base64 data)
5. Adds the message to state so the LLM can see and analyze the images
This enables the LLM to automatically receive and analyze images that were loaded via view_image tool,
without requiring explicit user prompts to describe the images.
"""
state_schema = ViewImageMiddlewareState
def _get_last_assistant_message(self, messages: list) -> AIMessage | None:
"""Get the last assistant message from the message list.
Args:
messages: List of messages
Returns:
Last AIMessage or None if not found
"""
for msg in reversed(messages):
if isinstance(msg, AIMessage):
return msg
return None
def _has_view_image_tool(self, message: AIMessage) -> bool:
"""Check if the assistant message contains view_image tool calls.
Args:
message: Assistant message to check
Returns:
True if message contains view_image tool calls
"""
if not hasattr(message, "tool_calls") or not message.tool_calls:
return False
return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls)
def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool:
"""Check if all tool calls in the assistant message have been completed.
Args:
messages: List of all messages
assistant_msg: The assistant message containing tool calls
Returns:
True if all tool calls have corresponding ToolMessages
"""
if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls:
return False
# Get all tool call IDs from the assistant message
tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")}
# Find the index of the assistant message
try:
assistant_idx = messages.index(assistant_msg)
except ValueError:
return False
# Get all ToolMessages after the assistant message
completed_tool_ids = set()
for msg in messages[assistant_idx + 1 :]:
if isinstance(msg, ToolMessage) and msg.tool_call_id:
completed_tool_ids.add(msg.tool_call_id)
# Check if all tool calls have been completed
return tool_call_ids.issubset(completed_tool_ids)
def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]:
"""Create a formatted message with all viewed image details.
Args:
state: Current state containing viewed_images
Returns:
List of content blocks (text and images) for the HumanMessage
"""
viewed_images = state.get("viewed_images", {})
if not viewed_images:
return ["No images have been viewed."]
# Build the message with image information
content_blocks: list[str | dict] = [{"type": "text", "text": "Here are the images you've viewed:"}]
for image_path, image_data in viewed_images.items():
mime_type = image_data.get("mime_type", "unknown")
base64_data = image_data.get("base64", "")
# Add text description
content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"})
# Add the actual image data so LLM can "see" it
if base64_data:
content_blocks.append(
{
"type": "image_url",
"image_url": {"url": f"data:{mime_type};base64,{base64_data}"},
}
)
return content_blocks
def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool:
"""Determine if we should inject an image details message.
Args:
state: Current state
Returns:
True if we should inject the message
"""
messages = state.get("messages", [])
if not messages:
return False
# Get the last assistant message
last_assistant_msg = self._get_last_assistant_message(messages)
if not last_assistant_msg:
return False
# Check if it has view_image tool calls
if not self._has_view_image_tool(last_assistant_msg):
return False
# Check if all tools have been completed
if not self._all_tools_completed(messages, last_assistant_msg):
return False
# Check if we've already added an image details message
# Look for a human message after the last assistant message that contains image details
assistant_idx = messages.index(last_assistant_msg)
for msg in messages[assistant_idx + 1 :]:
if isinstance(msg, HumanMessage):
content_str = str(msg.content)
if "Here are the images you've viewed" in content_str or "Here are the details of the images you've viewed" in content_str:
# Already added, don't add again
return False
return True
def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None:
"""Internal helper to inject image details message.
Args:
state: Current state
Returns:
State update with additional human message, or None if no update needed
"""
if not self._should_inject_image_message(state):
return None
# Create the image details message with text and image content
image_content = self._create_image_details_message(state)
# Create a new human message with mixed content (text + images)
human_msg = HumanMessage(content=image_content)
print("[ViewImageMiddleware] Injecting image details message with images before LLM call")
# Return state update with the new message
return {"messages": [human_msg]}
@override
def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject image details message before LLM call if view_image tools have completed (sync version).
This runs before each LLM call, checking if the previous turn included view_image
tool calls that have all completed. If so, it injects a human message with the image
details so the LLM can see and analyze the images.
Args:
state: Current state
runtime: Runtime context (unused but required by interface)
Returns:
State update with additional human message, or None if no update needed
"""
return self._inject_image_message(state)
@override
async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject image details message before LLM call if view_image tools have completed (async version).
This runs before each LLM call, checking if the previous turn included view_image
tool calls that have all completed. If so, it injects a human message with the image
details so the LLM can see and analyze the images.
Args:
state: Current state
runtime: Runtime context (unused but required by interface)
Returns:
State update with additional human message, or None if no update needed
"""
return self._inject_image_message(state)

View File

@@ -0,0 +1,55 @@
from typing import Annotated, NotRequired, TypedDict
from langchain.agents import AgentState
class SandboxState(TypedDict):
sandbox_id: NotRequired[str | None]
class ThreadDataState(TypedDict):
workspace_path: NotRequired[str | None]
uploads_path: NotRequired[str | None]
outputs_path: NotRequired[str | None]
class ViewedImageData(TypedDict):
base64: str
mime_type: str
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
"""Reducer for artifacts list - merges and deduplicates artifacts."""
if existing is None:
return new or []
if new is None:
return existing
# Use dict.fromkeys to deduplicate while preserving order
return list(dict.fromkeys(existing + new))
def merge_viewed_images(existing: dict[str, ViewedImageData] | None, new: dict[str, ViewedImageData] | None) -> dict[str, ViewedImageData]:
"""Reducer for viewed_images dict - merges image dictionaries.
Special case: If new is an empty dict {}, it clears the existing images.
This allows middlewares to clear the viewed_images state after processing.
"""
if existing is None:
return new or {}
if new is None:
return existing
# Special case: empty dict means clear all viewed images
if len(new) == 0:
return {}
# Merge dictionaries, new values override existing ones for same keys
return {**existing, **new}
class ThreadState(AgentState):
sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None]
title: NotRequired[str | None]
artifacts: Annotated[list[str], merge_artifacts]
todos: NotRequired[list | None]
uploaded_files: NotRequired[list[dict] | None]
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}