"""Middleware for intercepting clarification requests and presenting them to the user.""" from collections.abc import Callable from typing import override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain_core.messages import ToolMessage from langgraph.graph import END from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.types import Command class ClarificationMiddlewareState(AgentState): """Compatible with the `ThreadState` schema.""" pass class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]): """Intercepts clarification tool calls and interrupts execution to present questions to the user. When the model calls the `ask_clarification` tool, this middleware: 1. Intercepts the tool call before execution 2. Extracts the clarification question and metadata 3. Formats a user-friendly message 4. Returns a Command that interrupts execution and presents the question 5. Waits for user response before continuing This replaces the tool-based approach where clarification continued the conversation flow. """ state_schema = ClarificationMiddlewareState def _is_chinese(self, text: str) -> bool: """Check if text contains Chinese characters. Args: text: Text to check Returns: True if text contains Chinese characters """ return any("\u4e00" <= char <= "\u9fff" for char in text) def _format_clarification_message(self, args: dict) -> str: """Format the clarification arguments into a user-friendly message. Args: args: The tool call arguments containing clarification details Returns: Formatted message string """ question = args.get("question", "") clarification_type = args.get("clarification_type", "missing_info") context = args.get("context") options = args.get("options", []) # Type-specific icons type_icons = { "missing_info": "❓", "ambiguous_requirement": "🤔", "approach_choice": "🔀", "risk_confirmation": "⚠️", "suggestion": "💡", } icon = type_icons.get(clarification_type, "❓") # Build the message naturally message_parts = [] # Add icon and question together for a more natural flow if context: # If there's context, present it first as background message_parts.append(f"{icon} {context}") message_parts.append(f"\n{question}") else: # Just the question with icon message_parts.append(f"{icon} {question}") # Add options in a cleaner format if options and len(options) > 0: message_parts.append("") # blank line for spacing for i, option in enumerate(options, 1): message_parts.append(f" {i}. {option}") return "\n".join(message_parts) def _handle_clarification(self, request: ToolCallRequest) -> Command: """Handle clarification request and return command to interrupt execution. Args: request: Tool call request Returns: Command that interrupts execution with the formatted clarification message """ # Extract clarification arguments args = request.tool_call.get("args", {}) question = args.get("question", "") print("[ClarificationMiddleware] Intercepted clarification request") print(f"[ClarificationMiddleware] Question: {question}") # Format the clarification message formatted_message = self._format_clarification_message(args) # Get the tool call ID tool_call_id = request.tool_call.get("id", "") # Create a ToolMessage with the formatted question # This will be added to the message history tool_message = ToolMessage( content=formatted_message, tool_call_id=tool_call_id, name="ask_clarification", ) # Return a Command that: # 1. Adds the formatted tool message # 2. Interrupts execution by going to __end__ # Note: We don't add an extra AIMessage here - the frontend will detect # and display ask_clarification tool messages directly return Command( update={"messages": [tool_message]}, goto=END, ) @override def wrap_tool_call( self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command], ) -> ToolMessage | Command: """Intercept ask_clarification tool calls and interrupt execution (sync version). Args: request: Tool call request handler: Original tool execution handler Returns: Command that interrupts execution with the formatted clarification message """ # Check if this is an ask_clarification tool call if request.tool_call.get("name") != "ask_clarification": # Not a clarification call, execute normally return handler(request) return self._handle_clarification(request) @override async def awrap_tool_call( self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command], ) -> ToolMessage | Command: """Intercept ask_clarification tool calls and interrupt execution (async version). Args: request: Tool call request handler: Original tool execution handler (async) Returns: Command that interrupts execution with the formatted clarification message """ # Check if this is an ask_clarification tool call if request.tool_call.get("name") != "ask_clarification": # Not a clarification call, execute normally return await handler(request) return self._handle_clarification(request)