mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-22 13:44:46 +08:00
178 lines
6.2 KiB
Python
178 lines
6.2 KiB
Python
|
|
"""Middleware for intercepting clarification requests and presenting them to the user."""
|
||
|
|
|
||
|
|
from collections.abc import Callable
|
||
|
|
from typing import override
|
||
|
|
|
||
|
|
from langchain.agents import AgentState
|
||
|
|
from langchain.agents.middleware import AgentMiddleware
|
||
|
|
from langchain_core.messages import AIMessage, ToolMessage
|
||
|
|
from langgraph.graph import END
|
||
|
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||
|
|
from langgraph.types import Command
|
||
|
|
|
||
|
|
|
||
|
|
class ClarificationMiddlewareState(AgentState):
|
||
|
|
"""Compatible with the `ThreadState` schema."""
|
||
|
|
|
||
|
|
pass
|
||
|
|
|
||
|
|
|
||
|
|
class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
||
|
|
"""Intercepts clarification tool calls and interrupts execution to present questions to the user.
|
||
|
|
|
||
|
|
When the model calls the `ask_clarification` tool, this middleware:
|
||
|
|
1. Intercepts the tool call before execution
|
||
|
|
2. Extracts the clarification question and metadata
|
||
|
|
3. Formats a user-friendly message
|
||
|
|
4. Returns a Command that interrupts execution and presents the question
|
||
|
|
5. Waits for user response before continuing
|
||
|
|
|
||
|
|
This replaces the tool-based approach where clarification continued the conversation flow.
|
||
|
|
"""
|
||
|
|
|
||
|
|
state_schema = ClarificationMiddlewareState
|
||
|
|
|
||
|
|
def _is_chinese(self, text: str) -> bool:
|
||
|
|
"""Check if text contains Chinese characters.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
text: Text to check
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
True if text contains Chinese characters
|
||
|
|
"""
|
||
|
|
return any('\u4e00' <= char <= '\u9fff' for char in text)
|
||
|
|
|
||
|
|
def _format_clarification_message(self, args: dict) -> str:
|
||
|
|
"""Format the clarification arguments into a user-friendly message.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
args: The tool call arguments containing clarification details
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Formatted message string
|
||
|
|
"""
|
||
|
|
question = args.get("question", "")
|
||
|
|
clarification_type = args.get("clarification_type", "missing_info")
|
||
|
|
context = args.get("context")
|
||
|
|
options = args.get("options", [])
|
||
|
|
|
||
|
|
# Type-specific icons
|
||
|
|
type_icons = {
|
||
|
|
"missing_info": "❓",
|
||
|
|
"ambiguous_requirement": "🤔",
|
||
|
|
"approach_choice": "🔀",
|
||
|
|
"risk_confirmation": "⚠️",
|
||
|
|
"suggestion": "💡",
|
||
|
|
}
|
||
|
|
|
||
|
|
icon = type_icons.get(clarification_type, "❓")
|
||
|
|
|
||
|
|
# Build the message naturally
|
||
|
|
message_parts = []
|
||
|
|
|
||
|
|
# Add icon and question together for a more natural flow
|
||
|
|
if context:
|
||
|
|
# If there's context, present it first as background
|
||
|
|
message_parts.append(f"{icon} {context}")
|
||
|
|
message_parts.append(f"\n{question}")
|
||
|
|
else:
|
||
|
|
# Just the question with icon
|
||
|
|
message_parts.append(f"{icon} {question}")
|
||
|
|
|
||
|
|
# Add options in a cleaner format
|
||
|
|
if options and len(options) > 0:
|
||
|
|
message_parts.append("") # blank line for spacing
|
||
|
|
for i, option in enumerate(options, 1):
|
||
|
|
message_parts.append(f" {i}. {option}")
|
||
|
|
|
||
|
|
return "\n".join(message_parts)
|
||
|
|
|
||
|
|
def _handle_clarification(self, request: ToolCallRequest) -> Command:
|
||
|
|
"""Handle clarification request and return command to interrupt execution.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
request: Tool call request
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Command that interrupts execution with the formatted clarification message
|
||
|
|
"""
|
||
|
|
# Extract clarification arguments
|
||
|
|
args = request.tool_call.get("args", {})
|
||
|
|
question = args.get("question", "")
|
||
|
|
|
||
|
|
print("[ClarificationMiddleware] Intercepted clarification request")
|
||
|
|
print(f"[ClarificationMiddleware] Question: {question}")
|
||
|
|
|
||
|
|
# Format the clarification message
|
||
|
|
formatted_message = self._format_clarification_message(args)
|
||
|
|
|
||
|
|
# Get the tool call ID
|
||
|
|
tool_call_id = request.tool_call.get("id", "")
|
||
|
|
|
||
|
|
# Create a ToolMessage with the formatted question
|
||
|
|
# This will be added to the message history
|
||
|
|
tool_message = ToolMessage(
|
||
|
|
content=formatted_message,
|
||
|
|
tool_call_id=tool_call_id,
|
||
|
|
name="ask_clarification",
|
||
|
|
)
|
||
|
|
|
||
|
|
ai_response_message = AIMessage(content=formatted_message)
|
||
|
|
|
||
|
|
# Return a Command that:
|
||
|
|
# 1. Adds the formatted tool message (keeping the AI message intact)
|
||
|
|
# 2. Interrupts execution by going to __end__
|
||
|
|
# Note: We don't modify the AI message to preserve all fields (reasoning_content, tool_calls, etc.)
|
||
|
|
# This is especially important for thinking mode where reasoning_content is required
|
||
|
|
|
||
|
|
# Return Command to add the tool message and interrupt
|
||
|
|
return Command(
|
||
|
|
update={"messages": [tool_message, ai_response_message]},
|
||
|
|
goto=END,
|
||
|
|
)
|
||
|
|
|
||
|
|
@override
|
||
|
|
def wrap_tool_call(
|
||
|
|
self,
|
||
|
|
request: ToolCallRequest,
|
||
|
|
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||
|
|
) -> ToolMessage | Command:
|
||
|
|
"""Intercept ask_clarification tool calls and interrupt execution (sync version).
|
||
|
|
|
||
|
|
Args:
|
||
|
|
request: Tool call request
|
||
|
|
handler: Original tool execution handler
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Command that interrupts execution with the formatted clarification message
|
||
|
|
"""
|
||
|
|
# Check if this is an ask_clarification tool call
|
||
|
|
if request.tool_call.get("name") != "ask_clarification":
|
||
|
|
# Not a clarification call, execute normally
|
||
|
|
return handler(request)
|
||
|
|
|
||
|
|
return self._handle_clarification(request)
|
||
|
|
|
||
|
|
@override
|
||
|
|
async def awrap_tool_call(
|
||
|
|
self,
|
||
|
|
request: ToolCallRequest,
|
||
|
|
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||
|
|
) -> ToolMessage | Command:
|
||
|
|
"""Intercept ask_clarification tool calls and interrupt execution (async version).
|
||
|
|
|
||
|
|
Args:
|
||
|
|
request: Tool call request
|
||
|
|
handler: Original tool execution handler (async)
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Command that interrupts execution with the formatted clarification message
|
||
|
|
"""
|
||
|
|
# Check if this is an ask_clarification tool call
|
||
|
|
if request.tool_call.get("name") != "ask_clarification":
|
||
|
|
# Not a clarification call, execute normally
|
||
|
|
return await handler(request)
|
||
|
|
|
||
|
|
return self._handle_clarification(request)
|