feat(guardrails): add pre-tool-call authorization middleware with pluggable providers (#1240)

Add GuardrailMiddleware that evaluates every tool call before execution.
Three provider options: built-in AllowlistProvider (zero deps), OAP passport
providers (open standard), or custom providers loaded by class path.

- GuardrailProvider protocol with GuardrailRequest/Decision dataclasses
- GuardrailMiddleware (AgentMiddleware, position 5 in chain)
- AllowlistProvider for simple deny/allow by tool name
- GuardrailsConfig (Pydantic singleton, loaded from config.yaml)
- 25 tests covering allow/deny, fail-closed/open, async, GraphBubbleUp
- Comprehensive docs at backend/docs/GUARDRAILS.md

Closes #1213

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Uchi Uchibeke
2026-03-23 06:07:33 -04:00
committed by GitHub
parent fe75cb35ca
commit a29134d7c9
11 changed files with 1041 additions and 7 deletions

View File

@@ -90,6 +90,31 @@ def _build_runtime_middlewares(
middlewares.append(DanglingToolCallMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
guardrails_config = get_guardrails_config()
if guardrails_config.enabled and guardrails_config.provider:
import inspect
from deerflow.guardrails.middleware import GuardrailMiddleware
from deerflow.reflection import resolve_variable
provider_cls = resolve_variable(guardrails_config.provider.use)
provider_kwargs = dict(guardrails_config.provider.config) if guardrails_config.provider.config else {}
# Pass framework hint if the provider accepts it (e.g. for config discovery).
# Built-in providers like AllowlistProvider don't need it, so only inject
# when the constructor accepts 'framework' or '**kwargs'.
if "framework" not in provider_kwargs:
try:
sig = inspect.signature(provider_cls.__init__)
if "framework" in sig.parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
provider_kwargs["framework"] = "deerflow"
except (ValueError, TypeError):
pass
provider = provider_cls(**provider_kwargs)
middlewares.append(GuardrailMiddleware(provider, fail_closed=guardrails_config.fail_closed, passport=guardrails_config.passport))
middlewares.append(ToolErrorHandlingMiddleware())
return middlewares

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import load_guardrails_config_from_dict
from deerflow.config.memory_config import load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
@@ -107,6 +108,10 @@ class AppConfig(BaseModel):
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])

View File

@@ -0,0 +1,48 @@
"""Configuration for pre-tool-call authorization."""
from pydantic import BaseModel, Field
class GuardrailProviderConfig(BaseModel):
"""Configuration for a guardrail provider."""
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
class GuardrailsConfig(BaseModel):
"""Configuration for pre-tool-call authorization.
When enabled, every tool call passes through the configured provider
before execution. The provider receives tool name, arguments, and the
agent's passport reference, and returns an allow/deny decision.
"""
enabled: bool = Field(default=False, description="Enable guardrail middleware")
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
_guardrails_config: GuardrailsConfig | None = None
def get_guardrails_config() -> GuardrailsConfig:
"""Get the guardrails config, returning defaults if not loaded."""
global _guardrails_config
if _guardrails_config is None:
_guardrails_config = GuardrailsConfig()
return _guardrails_config
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
"""Load guardrails config from a dict (called during AppConfig loading)."""
global _guardrails_config
_guardrails_config = GuardrailsConfig.model_validate(data)
return _guardrails_config
def reset_guardrails_config() -> None:
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
global _guardrails_config
_guardrails_config = None

View File

@@ -0,0 +1,14 @@
"""Pre-tool-call authorization middleware."""
from deerflow.guardrails.builtin import AllowlistProvider
from deerflow.guardrails.middleware import GuardrailMiddleware
from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest
__all__ = [
"AllowlistProvider",
"GuardrailDecision",
"GuardrailMiddleware",
"GuardrailProvider",
"GuardrailReason",
"GuardrailRequest",
]

View File

@@ -0,0 +1,23 @@
"""Built-in guardrail providers that ship with DeerFlow."""
from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest
class AllowlistProvider:
"""Simple allowlist/denylist provider. No external dependencies."""
name = "allowlist"
def __init__(self, *, allowed_tools: list[str] | None = None, denied_tools: list[str] | None = None):
self._allowed = set(allowed_tools) if allowed_tools else None
self._denied = set(denied_tools) if denied_tools else set()
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
if self._allowed is not None and request.tool_name not in self._allowed:
return GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.tool_not_allowed", message=f"tool '{request.tool_name}' not in allowlist")])
if request.tool_name in self._denied:
return GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.tool_not_allowed", message=f"tool '{request.tool_name}' is denied")])
return GuardrailDecision(allow=True, reasons=[GuardrailReason(code="oap.allowed")])
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
return self.evaluate(request)

View File

@@ -0,0 +1,98 @@
"""GuardrailMiddleware - evaluates tool calls against a GuardrailProvider before execution."""
import logging
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest
logger = logging.getLogger(__name__)
class GuardrailMiddleware(AgentMiddleware[AgentState]):
"""Evaluate tool calls against a GuardrailProvider before execution.
Denied calls return an error ToolMessage so the agent can adapt.
If the provider raises, behavior depends on fail_closed:
- True (default): block the call
- False: allow it through with a warning
"""
def __init__(self, provider: GuardrailProvider, *, fail_closed: bool = True, passport: str | None = None):
self.provider = provider
self.fail_closed = fail_closed
self.passport = passport
def _build_request(self, request: ToolCallRequest) -> GuardrailRequest:
return GuardrailRequest(
tool_name=str(request.tool_call.get("name", "")),
tool_input=request.tool_call.get("args", {}),
agent_id=self.passport,
timestamp=datetime.now(UTC).isoformat(),
)
def _build_denied_message(self, request: ToolCallRequest, decision: GuardrailDecision) -> ToolMessage:
tool_name = str(request.tool_call.get("name", "unknown_tool"))
tool_call_id = str(request.tool_call.get("id", "missing_id"))
reason_text = decision.reasons[0].message if decision.reasons else "blocked by guardrail policy"
reason_code = decision.reasons[0].code if decision.reasons else "oap.denied"
return ToolMessage(
content=f"Guardrail denied: tool '{tool_name}' was blocked ({reason_code}). Reason: {reason_text}. Choose an alternative approach.",
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
gr = self._build_request(request)
try:
decision = self.provider.evaluate(gr)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception:
logger.exception("Guardrail provider error (sync)")
if self.fail_closed:
decision = GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.evaluator_error", message="guardrail provider error (fail-closed)")])
else:
return handler(request)
if not decision.allow:
logger.warning("Guardrail denied: tool=%s policy=%s code=%s", gr.tool_name, decision.policy_id, decision.reasons[0].code if decision.reasons else "unknown")
return self._build_denied_message(request, decision)
return handler(request)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
gr = self._build_request(request)
try:
decision = await self.provider.aevaluate(gr)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception:
logger.exception("Guardrail provider error (async)")
if self.fail_closed:
decision = GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.evaluator_error", message="guardrail provider error (fail-closed)")])
else:
return await handler(request)
if not decision.allow:
logger.warning("Guardrail denied: tool=%s policy=%s code=%s", gr.tool_name, decision.policy_id, decision.reasons[0].code if decision.reasons else "unknown")
return self._build_denied_message(request, decision)
return await handler(request)

View File

@@ -0,0 +1,56 @@
"""GuardrailProvider protocol and data structures for pre-tool-call authorization."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable
@dataclass
class GuardrailRequest:
"""Context passed to the provider for each tool call."""
tool_name: str
tool_input: dict[str, Any]
agent_id: str | None = None
thread_id: str | None = None
is_subagent: bool = False
timestamp: str = ""
@dataclass
class GuardrailReason:
"""Structured reason for an allow/deny decision (OAP reason object)."""
code: str
message: str = ""
@dataclass
class GuardrailDecision:
"""Provider's allow/deny verdict (aligned with OAP Decision object)."""
allow: bool
reasons: list[GuardrailReason] = field(default_factory=list)
policy_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class GuardrailProvider(Protocol):
"""Contract for pluggable tool-call authorization.
Any class with these methods works - no base class required.
Providers are loaded by class path via resolve_variable(),
the same mechanism DeerFlow uses for models, tools, and sandbox.
"""
name: str
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
"""Evaluate whether a tool call should proceed."""
...
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
"""Async variant."""
...