feat: implement tool-specific interrupts for create_react_agent (#572) (#659)

* feat: implement tool-specific interrupts for create_react_agent (#572)

Add selective tool interrupt capability allowing interrupts before specific tools
rather than all tools. Users can now configure which tools trigger interrupts via
the interrupt_before_tools parameter.

Changes:
- Create ToolInterceptor class to handle tool-specific interrupt logic
- Add interrupt_before_tools parameter to create_agent() function
- Extend Configuration with interrupt_before_tools field
- Add interrupt_before_tools to ChatRequest API
- Update nodes.py to pass interrupt configuration to agents
- Update app.py workflow to support tool interrupt configuration
- Add comprehensive unit tests for tool interceptor

Features:
- Selective tool interrupts: interrupt only specific tools by name
- Approval keywords: recognize user approval (approved, proceed, accept, etc.)
- Backward compatible: optional parameter, existing code unaffected
- Flexible: works with default tools and MCP-powered tools
- Works with existing resume mechanism for seamless workflow

Example usage:
  request = ChatRequest(
      messages=[...],
      interrupt_before_tools=['db_tool', 'sensitive_api']
  )
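The approval-keyword matching mentioned above can be sketched in isolation (a simplified stand-in for the real parsing logic, not the exact implementation):

```python
# Illustrative sketch of keyword-based approval parsing (simplified from
# the ToolInterceptor._parse_approval behavior described in this commit).
APPROVAL_KEYWORDS = [
    "approved", "approve", "yes", "proceed", "continue",
    "ok", "okay", "accepted", "accept",
]

def parse_approval(feedback: str) -> bool:
    """Return True when the feedback contains any approval keyword."""
    if not feedback:
        return False  # empty feedback is treated as rejection
    feedback_lower = feedback.lower().strip()
    return any(keyword in feedback_lower for keyword in APPROVAL_KEYWORDS)

print(parse_approval("Approved, go ahead"))  # True
print(parse_approval("denied"))              # False
```

Note that substring matching is deliberately permissive: "yes, proceed" and "[approved]" both pass, while anything without a recognized keyword defaults to rejection.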

* test: add comprehensive integration tests for tool-specific interrupts (#572)

Add 24 integration tests covering all aspects of the tool interceptor feature:

Test Coverage:
- Agent creation with tool interrupts
- Configuration support (with/without interrupts)
- ChatRequest API integration
- Multiple tools with selective interrupts
- User approval/rejection flows
- Tool wrapping and functionality preservation
- Error handling and edge cases
- Approval keyword recognition
- Complex tool inputs
- Logging and monitoring

All tests pass with 100% coverage of tool interceptor functionality.

Tests verify:
✓ Selective tool interrupts work correctly
✓ Only specified tools trigger interrupts
✓ Non-matching tools execute normally
✓ User feedback is properly parsed
✓ Tool functionality is preserved after wrapping
✓ Error handling works as expected
✓ Configuration options are properly respected
✓ Logging provides useful debugging info
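The selective-interrupt behavior these tests verify can be sketched with plain callables and unittest.mock standing in for the real LangChain tools and LangGraph interrupt (all names here are illustrative, not from the codebase):

```python
from unittest import mock

def wrap_with_interrupt(func, name, interrupt_before, interrupt_fn):
    """Wrap func so interrupt_fn is consulted only when name is listed."""
    def wrapped(*args, **kwargs):
        if name in interrupt_before:
            feedback = interrupt_fn(f"About to execute tool: '{name}'")
            if "approved" not in feedback.lower():
                return {"status": "rejected", "tool": name}
        return func(*args, **kwargs)
    return wrapped

mock_interrupt = mock.Mock(return_value="approved")
safe = wrap_with_interrupt(lambda x: x * 2, "safe_tool", ["db_tool"], mock_interrupt)
risky = wrap_with_interrupt(lambda x: x * 2, "db_tool", ["db_tool"], mock_interrupt)

assert safe(3) == 6
mock_interrupt.assert_not_called()   # non-matching tool ran without an interrupt
assert risky(3) == 6
mock_interrupt.assert_called()       # listed tool triggered the interrupt
```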

* fix: mock get_llm_by_type in agent creation test

Fix test_agent_creation_with_tool_interrupts which was failing because
get_llm_by_type() was being called before create_react_agent was mocked.

Changes:
- Add mock for get_llm_by_type in test
- Use context manager composition for multiple patches
- Test now passes and validates tool wrapping correctly

All 24 integration tests now pass successfully.

* refactor: use mock assertion methods for consistent and clearer error messages

Update integration tests to use mock assertion methods instead of direct
attribute checking for consistency and clearer error messages:

Changes:
- Replace 'assert mock_interrupt.called' with 'mock_interrupt.assert_called()'
- Replace 'assert not mock_interrupt.called' with 'mock_interrupt.assert_not_called()'

Benefits:
- Consistent with pytest-mock and unittest.mock best practices
- Clearer error messages when assertions fail
- Better IDE autocompletion support
- More professional test code

All 42 tests pass with improved assertion patterns.
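A minimal unittest.mock example of the difference:

```python
from unittest import mock

m = mock.Mock()
m("hello")

# Attribute check: on failure this raises a bare AssertionError with no context.
assert m.called

# Assertion methods: on failure these raise with descriptive messages, e.g.
# "Expected 'mock' to have been called." when the mock was never invoked.
m.assert_called()
m.assert_called_once_with("hello")
```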

* refactor: use default_factory for interrupt_before_tools consistency

Improve consistency between ChatRequest and Configuration implementations:

Changes:
- ChatRequest.interrupt_before_tools: use Field(default_factory=list) instead of Optional[List[str]] = None
- Remove unnecessary 'or []' conversion in app.py line 505
- Aligns with Configuration.interrupt_before_tools implementation pattern
- No functional changes - all tests still pass

Benefits:
- Consistent field definition across codebase
- Simpler and cleaner code
- Reduced chance of None/empty list bugs
- Better alignment with Pydantic best practices

All 42 tests passing.
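The default_factory pattern is the standard guard against shared mutable defaults in both dataclasses and Pydantic models; a stdlib dataclass sketch of the same idea:

```python
from dataclasses import dataclass, field

@dataclass
class Config:
    # default_factory gives each instance its own empty list, avoiding the
    # shared-mutable-default pitfall and any need for `value or []` guards.
    interrupt_before_tools: list[str] = field(default_factory=list)

a, b = Config(), Config()
a.interrupt_before_tools.append("db_tool")
assert b.interrupt_before_tools == []  # instances do not share state
```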

* refactor: improve tool input formatting in interrupt messages

Enhance tool input representation for better readability in interrupt messages:

Changes:
- Add json import for better formatting
- Create _format_tool_input() static method with JSON serialization
- Use JSON formatting for dicts, lists, tuples with indent=2
- Fall back to str() for non-serializable types
- Handle None input specially (returns 'No input')
- Improve interrupt message formatting with better spacing

Benefits:
- Complex tool inputs now display as readable JSON
- Nested structures are properly indented and visible
- Better user experience when reviewing tool inputs before approval
- Handles edge cases gracefully with fallbacks
- Improved logging output for debugging

Example improvements:
Before: {'query': 'SELECT...', 'limit': 10, 'nested': {'key': 'value'}}
After:
{
  "query": "SELECT...",
  "limit": 10,
  "nested": {
    "key": "value"
  }
}

All 42 tests still passing.
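A self-contained sketch of the formatting fallback described above (mirroring the behavior, not the exact code):

```python
import json
from typing import Any

def format_tool_input(tool_input: Any) -> str:
    """Render tool input as indented JSON where possible, else str()."""
    if tool_input is None:
        return "No input"
    if isinstance(tool_input, (dict, list, tuple)):
        try:
            # default=str handles non-JSON values (dates, UUIDs, ...) inline
            return json.dumps(tool_input, indent=2, default=str)
        except (TypeError, ValueError):
            return str(tool_input)
    return str(tool_input)

print(format_tool_input({"query": "SELECT...", "nested": {"key": "value"}}))
```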

* test: add comprehensive unit tests for tool input formatting
Author: Willem Jiang
Date: 2025-10-26 09:47:03 +08:00 (committed via GitHub)
Commit: bcc403ecd3 (parent 0441038672)
8 changed files with 1163 additions and 5 deletions


@@ -1,11 +1,17 @@
 # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: MIT
 import logging
+from typing import List, Optional
 from langgraph.prebuilt import create_react_agent
 from src.config.agents import AGENT_LLM_MAP
 from src.llms.llm import get_llm_by_type
 from src.prompts import apply_prompt_template
+from src.agents.tool_interceptor import wrap_tools_with_interceptor

 logger = logging.getLogger(__name__)

 # Create agents using configured LLM types
@@ -15,12 +21,33 @@ def create_agent(
     tools: list,
     prompt_template: str,
     pre_model_hook: callable = None,
+    interrupt_before_tools: Optional[List[str]] = None,
 ):
-    """Factory function to create agents with consistent configuration."""
+    """Factory function to create agents with consistent configuration.
+
+    Args:
+        agent_name: Name of the agent
+        agent_type: Type of agent (researcher, coder, etc.)
+        tools: List of tools available to the agent
+        prompt_template: Name of the prompt template to use
+        pre_model_hook: Optional hook to preprocess state before model invocation
+        interrupt_before_tools: Optional list of tool names to interrupt before execution
+
+    Returns:
+        A configured agent graph
+    """
+    # Wrap tools with interrupt logic if specified
+    processed_tools = tools
+    if interrupt_before_tools:
+        logger.info(
+            f"Creating agent '{agent_name}' with tool-specific interrupts: {interrupt_before_tools}"
+        )
+        processed_tools = wrap_tools_with_interceptor(tools, interrupt_before_tools)
     return create_react_agent(
         name=agent_name,
         model=get_llm_by_type(AGENT_LLM_MAP[agent_type]),
-        tools=tools,
+        tools=processed_tools,
         prompt=lambda state: apply_prompt_template(
             prompt_template, state, locale=state.get("locale", "en-US")
         ),


@@ -0,0 +1,205 @@ (new file)
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import json
import logging
from typing import Any, Callable, List, Optional

from langchain_core.tools import BaseTool
from langgraph.types import interrupt

logger = logging.getLogger(__name__)


class ToolInterceptor:
    """Intercepts tool calls and triggers interrupts for specified tools."""

    def __init__(self, interrupt_before_tools: Optional[List[str]] = None):
        """Initialize the interceptor with the list of tools to interrupt before.

        Args:
            interrupt_before_tools: List of tool names to interrupt before execution.
                If None or empty, no interrupts are triggered.
        """
        self.interrupt_before_tools = interrupt_before_tools or []
        logger.info(
            f"ToolInterceptor initialized with interrupt_before_tools: {self.interrupt_before_tools}"
        )

    def should_interrupt(self, tool_name: str) -> bool:
        """Check if execution should be interrupted before this tool.

        Args:
            tool_name: Name of the tool being called

        Returns:
            bool: True if the tool should trigger an interrupt, False otherwise
        """
        should_interrupt = tool_name in self.interrupt_before_tools
        if should_interrupt:
            logger.info(f"Tool '{tool_name}' marked for interrupt")
        return should_interrupt

    @staticmethod
    def _format_tool_input(tool_input: Any) -> str:
        """Format tool input for display in interrupt messages.

        Attempts to format as JSON for better readability, with a fallback
        to the string representation.

        Args:
            tool_input: The tool input to format

        Returns:
            str: Formatted representation of the tool input
        """
        if tool_input is None:
            return "No input"
        # Try to serialize as JSON first for better readability
        try:
            if isinstance(tool_input, (dict, list, tuple)):
                return json.dumps(tool_input, indent=2, default=str)
            elif isinstance(tool_input, str):
                return tool_input
            else:
                # Fall back to string representation for other types
                return str(tool_input)
        except (TypeError, ValueError):
            # JSON serialization failed, use string representation
            return str(tool_input)

    @staticmethod
    def wrap_tool(tool: BaseTool, interceptor: "ToolInterceptor") -> BaseTool:
        """Wrap a tool to add interrupt logic by replacing its function.

        Args:
            tool: The tool to wrap
            interceptor: The ToolInterceptor instance

        Returns:
            BaseTool: The wrapped tool with interrupt capability
        """
        original_func = tool.func

        def intercepted_func(*args: Any, **kwargs: Any) -> Any:
            """Execute the tool with an interrupt check."""
            tool_name = tool.name
            # Format tool input for display
            tool_input = args[0] if args else kwargs
            tool_input_repr = ToolInterceptor._format_tool_input(tool_input)
            if interceptor.should_interrupt(tool_name):
                logger.info(
                    f"Interrupting before tool '{tool_name}' with input: {tool_input_repr}"
                )
                # Trigger interrupt and wait for user feedback
                feedback = interrupt(
                    f"About to execute tool: '{tool_name}'\n\nInput:\n{tool_input_repr}\n\nApprove execution?"
                )
                logger.info(f"Interrupt feedback for '{tool_name}': {feedback}")
                # Check if the user approved
                if not ToolInterceptor._parse_approval(feedback):
                    logger.warning(f"User rejected execution of tool '{tool_name}'")
                    return {
                        "error": "Tool execution rejected by user",
                        "tool": tool_name,
                        "status": "rejected",
                    }
                logger.info(f"User approved execution of tool '{tool_name}'")
            # Execute the original tool
            try:
                result = original_func(*args, **kwargs)
                logger.debug(f"Tool '{tool_name}' execution completed")
                return result
            except Exception as e:
                logger.error(f"Error executing tool '{tool_name}': {str(e)}")
                raise

        # Replace the function on the tool.
        # Use object.__setattr__ to bypass Pydantic validation.
        object.__setattr__(tool, "func", intercepted_func)
        return tool

    @staticmethod
    def _parse_approval(feedback: str) -> bool:
        """Parse user feedback to determine if tool execution was approved.

        Args:
            feedback: The feedback string from the user

        Returns:
            bool: True if the feedback indicates approval, False otherwise
        """
        if not feedback:
            logger.warning("Empty feedback received, treating as rejection")
            return False
        feedback_lower = feedback.lower().strip()
        # Check for approval keywords
        approval_keywords = [
            "approved",
            "approve",
            "yes",
            "proceed",
            "continue",
            "ok",
            "okay",
            "accepted",
            "accept",
            "[approved]",
        ]
        for keyword in approval_keywords:
            if keyword in feedback_lower:
                return True
        # Default to rejection if no approval keywords are found
        logger.warning(
            f"No approval keywords found in feedback: {feedback}. Treating as rejection."
        )
        return False


def wrap_tools_with_interceptor(
    tools: List[BaseTool], interrupt_before_tools: Optional[List[str]] = None
) -> List[BaseTool]:
    """Wrap multiple tools with interrupt logic.

    Args:
        tools: List of tools to wrap
        interrupt_before_tools: List of tool names to interrupt before

    Returns:
        List[BaseTool]: List of wrapped tools
    """
    if not interrupt_before_tools:
        logger.debug("No tool interrupts configured, returning tools as-is")
        return tools
    logger.info(
        f"Wrapping {len(tools)} tools with interrupt logic for: {interrupt_before_tools}"
    )
    interceptor = ToolInterceptor(interrupt_before_tools)
    wrapped_tools = []
    for tool in tools:
        try:
            wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor)
            wrapped_tools.append(wrapped_tool)
            logger.debug(f"Wrapped tool: {tool.name}")
        except Exception as e:
            logger.error(f"Failed to wrap tool {tool.name}: {str(e)}")
            # Keep the original tool if wrapping fails
            wrapped_tools.append(tool)
    logger.info(f"Successfully wrapped {len(wrapped_tools)} tools")
    return wrapped_tools


@@ -54,6 +54,9 @@ class Configuration:
     enforce_web_search: bool = (
         False  # Enforce at least one web search step in every plan
     )
+    interrupt_before_tools: list[str] = field(
+        default_factory=list
+    )  # List of tool names to interrupt before execution

     @classmethod
     def from_runnable_config(


@@ -914,7 +914,12 @@ async def _setup_and_execute_agent_step(
         llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
         pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
         agent = create_agent(
-            agent_type, agent_type, loaded_tools, agent_type, pre_model_hook
+            agent_type,
+            agent_type,
+            loaded_tools,
+            agent_type,
+            pre_model_hook,
+            interrupt_before_tools=configurable.interrupt_before_tools,
         )
         return await _execute_agent_step(state, agent, agent_type)
     else:
@@ -922,7 +927,12 @@ async def _setup_and_execute_agent_step(
         llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
         pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
         agent = create_agent(
-            agent_type, agent_type, default_tools, agent_type, pre_model_hook
+            agent_type,
+            agent_type,
+            default_tools,
+            agent_type,
+            pre_model_hook,
+            interrupt_before_tools=configurable.interrupt_before_tools,
         )
         return await _execute_agent_step(state, agent, agent_type)


@@ -6,7 +6,7 @@ import base64
 import json
 import logging
 import os
-from typing import Annotated, Any, List, cast
+from typing import Annotated, Any, List, Optional, cast
 from uuid import uuid4

 from fastapi import FastAPI, HTTPException, Query
@@ -127,6 +127,7 @@ async def chat_stream(request: ChatRequest):
             request.enable_clarification,
             request.max_clarification_rounds,
             request.locale,
+            request.interrupt_before_tools,
         ),
         media_type="text/event-stream",
     )
@@ -453,6 +454,7 @@ async def _astream_workflow_generator(
     enable_clarification: bool,
     max_clarification_rounds: int,
     locale: str = "en-US",
+    interrupt_before_tools: Optional[List[str]] = None,
 ):
     # Process initial messages
     for message in messages:
@@ -500,6 +502,7 @@ async def _astream_workflow_generator(
         "mcp_settings": mcp_settings,
         "report_style": report_style.value,
         "enable_deep_thinking": enable_deep_thinking,
+        "interrupt_before_tools": interrupt_before_tools,
         "recursion_limit": get_recursion_limit(),
     }


@@ -76,6 +76,10 @@ class ChatRequest(BaseModel):
         None,
         description="Maximum number of clarification rounds (default: None, uses State default=3)",
     )
+    interrupt_before_tools: List[str] = Field(
+        default_factory=list,
+        description="List of tool names to interrupt before execution (e.g., ['db_tool', 'api_tool'])",
+    )


 class TTSRequest(BaseModel):