mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-22 21:54:45 +08:00
feat: add clarification feature (#13)
This commit is contained in:
177
backend/src/agents/middlewares/clarification_middleware.py
Normal file
177
backend/src/agents/middlewares/clarification_middleware.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Middleware for intercepting clarification requests and presenting them to the user."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.graph import END
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
|
||||
class ClarificationMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
||||
"""Intercepts clarification tool calls and interrupts execution to present questions to the user.
|
||||
|
||||
When the model calls the `ask_clarification` tool, this middleware:
|
||||
1. Intercepts the tool call before execution
|
||||
2. Extracts the clarification question and metadata
|
||||
3. Formats a user-friendly message
|
||||
4. Returns a Command that interrupts execution and presents the question
|
||||
5. Waits for user response before continuing
|
||||
|
||||
This replaces the tool-based approach where clarification continued the conversation flow.
|
||||
"""
|
||||
|
||||
state_schema = ClarificationMiddlewareState
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text contains Chinese characters
|
||||
"""
|
||||
return any('\u4e00' <= char <= '\u9fff' for char in text)
|
||||
|
||||
def _format_clarification_message(self, args: dict) -> str:
|
||||
"""Format the clarification arguments into a user-friendly message.
|
||||
|
||||
Args:
|
||||
args: The tool call arguments containing clarification details
|
||||
|
||||
Returns:
|
||||
Formatted message string
|
||||
"""
|
||||
question = args.get("question", "")
|
||||
clarification_type = args.get("clarification_type", "missing_info")
|
||||
context = args.get("context")
|
||||
options = args.get("options", [])
|
||||
|
||||
# Type-specific icons
|
||||
type_icons = {
|
||||
"missing_info": "❓",
|
||||
"ambiguous_requirement": "🤔",
|
||||
"approach_choice": "🔀",
|
||||
"risk_confirmation": "⚠️",
|
||||
"suggestion": "💡",
|
||||
}
|
||||
|
||||
icon = type_icons.get(clarification_type, "❓")
|
||||
|
||||
# Build the message naturally
|
||||
message_parts = []
|
||||
|
||||
# Add icon and question together for a more natural flow
|
||||
if context:
|
||||
# If there's context, present it first as background
|
||||
message_parts.append(f"{icon} {context}")
|
||||
message_parts.append(f"\n{question}")
|
||||
else:
|
||||
# Just the question with icon
|
||||
message_parts.append(f"{icon} {question}")
|
||||
|
||||
# Add options in a cleaner format
|
||||
if options and len(options) > 0:
|
||||
message_parts.append("") # blank line for spacing
|
||||
for i, option in enumerate(options, 1):
|
||||
message_parts.append(f" {i}. {option}")
|
||||
|
||||
return "\n".join(message_parts)
|
||||
|
||||
def _handle_clarification(self, request: ToolCallRequest) -> Command:
|
||||
"""Handle clarification request and return command to interrupt execution.
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Extract clarification arguments
|
||||
args = request.tool_call.get("args", {})
|
||||
question = args.get("question", "")
|
||||
|
||||
print("[ClarificationMiddleware] Intercepted clarification request")
|
||||
print(f"[ClarificationMiddleware] Question: {question}")
|
||||
|
||||
# Format the clarification message
|
||||
formatted_message = self._format_clarification_message(args)
|
||||
|
||||
# Get the tool call ID
|
||||
tool_call_id = request.tool_call.get("id", "")
|
||||
|
||||
# Create a ToolMessage with the formatted question
|
||||
# This will be added to the message history
|
||||
tool_message = ToolMessage(
|
||||
content=formatted_message,
|
||||
tool_call_id=tool_call_id,
|
||||
name="ask_clarification",
|
||||
)
|
||||
|
||||
ai_response_message = AIMessage(content=formatted_message)
|
||||
|
||||
# Return a Command that:
|
||||
# 1. Adds the formatted tool message (keeping the AI message intact)
|
||||
# 2. Interrupts execution by going to __end__
|
||||
# Note: We don't modify the AI message to preserve all fields (reasoning_content, tool_calls, etc.)
|
||||
# This is especially important for thinking mode where reasoning_content is required
|
||||
|
||||
# Return Command to add the tool message and interrupt
|
||||
return Command(
|
||||
update={"messages": [tool_message, ai_response_message]},
|
||||
goto=END,
|
||||
)
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept ask_clarification tool calls and interrupt execution (sync version).
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
handler: Original tool execution handler
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Check if this is an ask_clarification tool call
|
||||
if request.tool_call.get("name") != "ask_clarification":
|
||||
# Not a clarification call, execute normally
|
||||
return handler(request)
|
||||
|
||||
return self._handle_clarification(request)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept ask_clarification tool calls and interrupt execution (async version).
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
handler: Original tool execution handler (async)
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Check if this is an ask_clarification tool call
|
||||
if request.tool_call.get("name") != "ask_clarification":
|
||||
# Not a clarification call, execute normally
|
||||
return await handler(request)
|
||||
|
||||
return self._handle_clarification(request)
|
||||
Reference in New Issue
Block a user