"""Middleware for intercepting clarification requests and presenting them to the user."""

import json
import logging
from collections.abc import Awaitable, Callable
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.graph import END
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command

logger = logging.getLogger(__name__)

# Name of the tool whose calls are intercepted by the middleware.
_CLARIFICATION_TOOL_NAME = "ask_clarification"

# Icon shown when the clarification type is missing or unknown.
_DEFAULT_ICON = "❓"

# Clarification type -> icon prefixed to the rendered message.
_TYPE_ICONS = {
    "missing_info": "❓",
    "ambiguous_requirement": "🤔",
    "approach_choice": "🔀",
    "risk_confirmation": "⚠️",
    "suggestion": "💡",
}


class ClarificationMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""


class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
    """Intercepts clarification tool calls and ends the run to present the question.

    When the model calls the `ask_clarification` tool, this middleware:

    1. Intercepts the tool call before the real tool handler runs
    2. Extracts the clarification question and metadata from the call arguments
    3. Formats a user-friendly message
    4. Returns a `Command` that appends that message as a `ToolMessage` and
       routes the graph to `END`

    Calls to any other tool are passed through to the handler unchanged.

    This replaces the tool-based approach where clarification continued the
    conversation flow.
    """

    state_schema = ClarificationMiddlewareState

    def _is_chinese(self, text: str) -> bool:
        """Check if text contains Chinese characters.

        Args:
            text: Text to check

        Returns:
            True if any character lies in the range U+4E00..U+9FFF
        """
        return any("\u4e00" <= char <= "\u9fff" for char in text)

    @staticmethod
    def _normalize_options(options: object) -> list:
        """Coerce the raw `options` tool argument into a list.

        Tool arguments are produced by a model, so the value is not guaranteed
        to be a native list. In particular an array may arrive serialized as a
        JSON string; iterating that string directly would render every
        character as its own option.

        Args:
            options: Raw value of the `options` argument

        Returns:
            A list of options; empty when `options` is falsy
        """
        if not options:
            return []
        if isinstance(options, str):
            try:
                decoded = json.loads(options)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                return [options]
            # Only a decoded array is a list of options; any other JSON value
            # means the string itself was meant as a single option.
            return decoded if isinstance(decoded, list) else [options]
        try:
            return list(options)
        except TypeError:
            # Non-iterable scalar (e.g. a number): show it as one option.
            return [options]

    def _format_clarification_message(self, args: dict) -> str:
        """Format the clarification arguments into a user-friendly message.

        Args:
            args: The tool call arguments containing clarification details.
                Reads the keys `question`, `clarification_type`, `context`
                and `options`; all of them are optional.

        Returns:
            Formatted message string
        """
        question = args.get("question", "")
        clarification_type = args.get("clarification_type", "missing_info")
        context = args.get("context")
        options = self._normalize_options(args.get("options"))

        icon = _TYPE_ICONS.get(clarification_type, _DEFAULT_ICON)

        if context:
            # Context goes first as background; the leading "\n" on the
            # question yields a blank line between the two once joined.
            message_parts = [f"{icon} {context}", f"\n{question}"]
        else:
            message_parts = [f"{icon} {question}"]

        if options:
            message_parts.append("")  # blank line for spacing
            message_parts.extend(
                f" {i}. {option}" for i, option in enumerate(options, 1)
            )

        return "\n".join(message_parts)

    def _handle_clarification(self, request: ToolCallRequest) -> Command:
        """Handle clarification request and return command to interrupt execution.

        Args:
            request: Tool call request

        Returns:
            Command that adds the formatted clarification as a `ToolMessage`
            and routes to `END`
        """
        # `args` may be present but null, so fall back to an empty dict.
        args = request.tool_call.get("args") or {}
        question = args.get("question", "")

        logger.info("Intercepted clarification request")
        logger.debug("Clarification question: %s", question)

        formatted_message = self._format_clarification_message(args)
        tool_call_id = request.tool_call.get("id", "")

        # This message is what gets added to the message history.
        tool_message = ToolMessage(
            content=formatted_message,
            tool_call_id=tool_call_id,
            name=_CLARIFICATION_TOOL_NAME,
        )

        # Note: We don't add an extra AIMessage here - the frontend will detect
        # and display ask_clarification tool messages directly.
        return Command(
            update={"messages": [tool_message]},
            goto=END,
        )

    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Intercept ask_clarification tool calls and interrupt execution (sync version).

        Args:
            request: Tool call request
            handler: Original tool execution handler

        Returns:
            The handler's result for any other tool; for `ask_clarification`,
            a Command that interrupts execution with the formatted message
        """
        if request.tool_call.get("name") != _CLARIFICATION_TOOL_NAME:
            # Not a clarification call, execute normally
            return handler(request)

        return self._handle_clarification(request)

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Intercept ask_clarification tool calls and interrupt execution (async version).

        Args:
            request: Tool call request
            handler: Original tool execution handler (async)

        Returns:
            The awaited handler result for any other tool; for
            `ask_clarification`, a Command that interrupts execution with the
            formatted message
        """
        if request.tool_call.get("name") != _CLARIFICATION_TOOL_NAME:
            # Not a clarification call, execute normally
            return await handler(request)

        return self._handle_clarification(request)