diff --git a/src/agents/agents.py b/src/agents/agents.py index 0bb0e84..a5c16dd 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -1,11 +1,17 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import logging +from typing import List, Optional + from langgraph.prebuilt import create_react_agent from src.config.agents import AGENT_LLM_MAP from src.llms.llm import get_llm_by_type from src.prompts import apply_prompt_template +from src.agents.tool_interceptor import wrap_tools_with_interceptor + +logger = logging.getLogger(__name__) # Create agents using configured LLM types @@ -15,12 +21,33 @@ def create_agent( tools: list, prompt_template: str, pre_model_hook: callable = None, + interrupt_before_tools: Optional[List[str]] = None, ): - """Factory function to create agents with consistent configuration.""" + """Factory function to create agents with consistent configuration. + + Args: + agent_name: Name of the agent + agent_type: Type of agent (researcher, coder, etc.) + tools: List of tools available to the agent + prompt_template: Name of the prompt template to use + pre_model_hook: Optional hook to preprocess state before model invocation + interrupt_before_tools: Optional list of tool names to interrupt before execution + + Returns: + A configured agent graph + """ + # Wrap tools with interrupt logic if specified + processed_tools = tools + if interrupt_before_tools: + logger.info( + f"Creating agent '{agent_name}' with tool-specific interrupts: {interrupt_before_tools}" + ) + processed_tools = wrap_tools_with_interceptor(tools, interrupt_before_tools) + return create_react_agent( name=agent_name, model=get_llm_by_type(AGENT_LLM_MAP[agent_type]), - tools=tools, + tools=processed_tools, prompt=lambda state: apply_prompt_template( prompt_template, state, locale=state.get("locale", "en-US") ), diff --git a/src/agents/tool_interceptor.py b/src/agents/tool_interceptor.py new file mode 100644 index 0000000..2adfe40 --- /dev/null +++ b/src/agents/tool_interceptor.py @@ -0,0 +1,205 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import json +import logging +from typing import Any, Callable, List, Optional + +from langchain_core.tools import BaseTool +from langgraph.types import interrupt + +logger = logging.getLogger(__name__) + + +class ToolInterceptor: + """Intercepts tool calls and triggers interrupts for specified tools.""" + + def __init__(self, interrupt_before_tools: Optional[List[str]] = None): + """Initialize the interceptor with list of tools to interrupt before. + + Args: + interrupt_before_tools: List of tool names to interrupt before execution. + If None or empty, no interrupts are triggered. + """ + self.interrupt_before_tools = interrupt_before_tools or [] + logger.info( + f"ToolInterceptor initialized with interrupt_before_tools: {self.interrupt_before_tools}" + ) + + def should_interrupt(self, tool_name: str) -> bool: + """Check if execution should be interrupted before this tool. + + Args: + tool_name: Name of the tool being called + + Returns: + bool: True if tool should trigger an interrupt, False otherwise + """ + should_interrupt = tool_name in self.interrupt_before_tools + if should_interrupt: + logger.info(f"Tool '{tool_name}' marked for interrupt") + return should_interrupt + + @staticmethod + def _format_tool_input(tool_input: Any) -> str: + """Format tool input for display in interrupt messages. + + Attempts to format as JSON for better readability, with fallback to string representation. + + Args: + tool_input: The tool input to format + + Returns: + str: Formatted representation of the tool input + """ + if tool_input is None: + return "No input" + + # Try to serialize as JSON first for better readability + try: + # Handle dictionaries and other JSON-serializable objects + if isinstance(tool_input, (dict, list, tuple)): + return json.dumps(tool_input, indent=2, default=str) + elif isinstance(tool_input, str): + return tool_input + else: + # For other types, try to convert to dict if it has __dict__ + # Otherwise fall back to string representation + return str(tool_input) + except (TypeError, ValueError): + # JSON serialization failed, use string representation + return str(tool_input) + + @staticmethod + def wrap_tool( + tool: BaseTool, interceptor: "ToolInterceptor" + ) -> BaseTool: + """Wrap a tool to add interrupt logic by creating a wrapper. + + Args: + tool: The tool to wrap + interceptor: The ToolInterceptor instance + + Returns: + BaseTool: The wrapped tool with interrupt capability + """ + original_func = tool.func + + def intercepted_func(*args: Any, **kwargs: Any) -> Any: + """Execute the tool with interrupt check.""" + tool_name = tool.name + # Format tool input for display + tool_input = args[0] if args else kwargs + tool_input_repr = ToolInterceptor._format_tool_input(tool_input) + + if interceptor.should_interrupt(tool_name): + logger.info( + f"Interrupting before tool '{tool_name}' with input: {tool_input_repr}" + ) + # Trigger interrupt and wait for user feedback + feedback = interrupt( + f"About to execute tool: '{tool_name}'\n\nInput:\n{tool_input_repr}\n\nApprove execution?" + ) + + logger.info(f"Interrupt feedback for '{tool_name}': {feedback}") + + # Check if user approved + if not ToolInterceptor._parse_approval(feedback): + logger.warning(f"User rejected execution of tool '{tool_name}'") + return { + "error": f"Tool execution rejected by user", + "tool": tool_name, + "status": "rejected", + } + + logger.info(f"User approved execution of tool '{tool_name}'") + + # Execute the original tool + try: + result = original_func(*args, **kwargs) + logger.debug(f"Tool '{tool_name}' execution completed") + return result + except Exception as e: + logger.error(f"Error executing tool '{tool_name}': {str(e)}") + raise + + # Replace the function and update the tool + # Use object.__setattr__ to bypass Pydantic validation + object.__setattr__(tool, "func", intercepted_func) + return tool + + @staticmethod + def _parse_approval(feedback: str) -> bool: + """Parse user feedback to determine if tool execution was approved. + + Args: + feedback: The feedback string from the user + + Returns: + bool: True if feedback indicates approval, False otherwise + """ + if not feedback: + logger.warning("Empty feedback received, treating as rejection") + return False + + feedback_lower = feedback.lower().strip() + + # Check for approval keywords + approval_keywords = [ + "approved", + "approve", + "yes", + "proceed", + "continue", + "ok", + "okay", + "accepted", + "accept", + "[approved]", + ] + + for keyword in approval_keywords: + if keyword in feedback_lower: + return True + + # Default to rejection if no approval keywords found + logger.warning( + f"No approval keywords found in feedback: {feedback}. Treating as rejection." + ) + return False + + +def wrap_tools_with_interceptor( + tools: List[BaseTool], interrupt_before_tools: Optional[List[str]] = None +) -> List[BaseTool]: + """Wrap multiple tools with interrupt logic. + + Args: + tools: List of tools to wrap + interrupt_before_tools: List of tool names to interrupt before + + Returns: + List[BaseTool]: List of wrapped tools + """ + if not interrupt_before_tools: + logger.debug("No tool interrupts configured, returning tools as-is") + return tools + + logger.info( + f"Wrapping {len(tools)} tools with interrupt logic for: {interrupt_before_tools}" + ) + interceptor = ToolInterceptor(interrupt_before_tools) + + wrapped_tools = [] + for tool in tools: + try: + wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor) + wrapped_tools.append(wrapped_tool) + logger.debug(f"Wrapped tool: {tool.name}") + except Exception as e: + logger.error(f"Failed to wrap tool {tool.name}: {str(e)}") + # Add original tool if wrapping fails + wrapped_tools.append(tool) + + logger.info(f"Successfully wrapped {len(wrapped_tools)} tools") + return wrapped_tools diff --git a/src/config/configuration.py b/src/config/configuration.py index 093b7f1..299d6cd 100644 --- a/src/config/configuration.py +++ b/src/config/configuration.py @@ -54,6 +54,9 @@ class Configuration: enforce_web_search: bool = ( False # Enforce at least one web search step in every plan ) + interrupt_before_tools: list[str] = field( + default_factory=list + ) # List of tool names to interrupt before execution @classmethod def from_runnable_config( diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 726d94e..b80990d 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -914,7 +914,12 @@ async def _setup_and_execute_agent_step( llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type]) pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages) agent = create_agent( - agent_type, agent_type, loaded_tools, agent_type, pre_model_hook + agent_type, + agent_type, + loaded_tools, + agent_type, + pre_model_hook, + interrupt_before_tools=configurable.interrupt_before_tools, ) return await _execute_agent_step(state, agent, agent_type) else: @@ -922,7 +927,12 @@ async def _setup_and_execute_agent_step( llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type]) pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages) agent = create_agent( - agent_type, agent_type, default_tools, agent_type, pre_model_hook + agent_type, + agent_type, + default_tools, + agent_type, + pre_model_hook, + interrupt_before_tools=configurable.interrupt_before_tools, ) return await _execute_agent_step(state, agent, agent_type) diff --git a/src/server/app.py b/src/server/app.py index d650721..b02e15e 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -6,7 +6,7 @@ import base64 import json import logging import os -from typing import Annotated, Any, List, cast +from typing import Annotated, Any, List, Optional, cast from uuid import uuid4 from fastapi import FastAPI, HTTPException, Query @@ -127,6 +127,7 @@ async def chat_stream(request: ChatRequest): request.enable_clarification, request.max_clarification_rounds, request.locale, + request.interrupt_before_tools, ), media_type="text/event-stream", ) @@ -453,6 +454,7 @@ async def _astream_workflow_generator( enable_clarification: bool, max_clarification_rounds: int, locale: str = "en-US", + interrupt_before_tools: Optional[List[str]] = None, ): # Process initial messages for message in messages: @@ -500,6 +502,7 @@ async def _astream_workflow_generator( "mcp_settings": mcp_settings, "report_style": report_style.value, "enable_deep_thinking": enable_deep_thinking, + "interrupt_before_tools": interrupt_before_tools, "recursion_limit": get_recursion_limit(), } diff --git a/src/server/chat_request.py b/src/server/chat_request.py index e90b553..3dc18f4 100644 --- a/src/server/chat_request.py +++ b/src/server/chat_request.py @@ -76,6 +76,10 @@ class ChatRequest(BaseModel): None, description="Maximum number of clarification rounds (default: None, uses State default=3)", ) + interrupt_before_tools: List[str] = Field( + default_factory=list, + description="List of tool names to interrupt before execution (e.g., ['db_tool', 'api_tool'])", + ) class TTSRequest(BaseModel): diff --git a/tests/integration/test_tool_interceptor_integration.py b/tests/integration/test_tool_interceptor_integration.py new file mode 100644 index 0000000..a28fee4 --- /dev/null +++ b/tests/integration/test_tool_interceptor_integration.py @@ -0,0 +1,473 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Integration tests for tool-specific interrupts feature (Issue #572). + +Tests the complete flow of selective tool interrupts including: +- Tool wrapping with interrupt logic +- Agent creation with interrupt configuration +- Tool execution with user feedback +- Resume mechanism after interrupt +""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock, MagicMock, call +from typing import Any + +from langchain_core.tools import tool +from langchain_core.messages import HumanMessage + +from src.agents.agents import create_agent +from src.agents.tool_interceptor import ToolInterceptor, wrap_tools_with_interceptor +from src.config.configuration import Configuration +from src.server.chat_request import ChatRequest + + +class TestToolInterceptorIntegration: + """Integration tests for tool interceptor with agent workflow.""" + + def test_agent_creation_with_tool_interrupts(self): + """Test creating an agent with tool interrupts configured.""" + @tool + def search_tool(query: str) -> str: + """Search the web.""" + return f"Search results for: {query}" + + @tool + def db_tool(query: str) -> str: + """Query database.""" + return f"DB results for: {query}" + + tools = [search_tool, db_tool] + + # Create agent with interrupts on db_tool only + with patch("src.agents.agents.create_react_agent") as mock_create, \ + patch("src.agents.agents.get_llm_by_type") as mock_llm: + mock_create.return_value = MagicMock() + mock_llm.return_value = MagicMock() + + agent = create_agent( + agent_name="test_agent", + agent_type="researcher", + tools=tools, + prompt_template="researcher", + interrupt_before_tools=["db_tool"], + ) + + # Verify create_react_agent was called with wrapped tools + assert mock_create.called + call_args = mock_create.call_args + wrapped_tools = call_args.kwargs["tools"] + + # Should have wrapped the tools + assert len(wrapped_tools) == 2 + assert wrapped_tools[0].name == "search_tool" + assert wrapped_tools[1].name == "db_tool" + + def test_configuration_with_tool_interrupts(self): + """Test Configuration object with interrupt_before_tools.""" + config = Configuration( + interrupt_before_tools=["db_tool", "api_write_tool"], + max_step_num=3, + max_search_results=5, + ) + + assert config.interrupt_before_tools == ["db_tool", "api_write_tool"] + assert config.max_step_num == 3 + assert config.max_search_results == 5 + + def test_configuration_default_no_interrupts(self): + """Test Configuration defaults to no interrupts.""" + config = Configuration() + assert config.interrupt_before_tools == [] + + def test_chat_request_with_tool_interrupts(self): + """Test ChatRequest with interrupt_before_tools.""" + request = ChatRequest( + messages=[{"role": "user", "content": "Search for X"}], + interrupt_before_tools=["db_tool", "payment_api"], + ) + + assert request.interrupt_before_tools == ["db_tool", "payment_api"] + + def test_chat_request_interrupt_feedback_with_tool_interrupts(self): + """Test ChatRequest with both interrupt_before_tools and interrupt_feedback.""" + request = ChatRequest( + messages=[{"role": "user", "content": "Research topic"}], + interrupt_before_tools=["db_tool"], + interrupt_feedback="approved", + ) + + assert request.interrupt_before_tools == ["db_tool"] + assert request.interrupt_feedback == "approved" + + def test_multiple_tools_selective_interrupt(self): + """Test that only specified tools trigger interrupts.""" + @tool + def tool_a(x: str) -> str: + """Tool A""" + return f"A: {x}" + + @tool + def tool_b(x: str) -> str: + """Tool B""" + return f"B: {x}" + + @tool + def tool_c(x: str) -> str: + """Tool C""" + return f"C: {x}" + + tools = [tool_a, tool_b, tool_c] + interceptor = ToolInterceptor(["tool_b"]) + + # Wrap all tools + wrapped_tools = wrap_tools_with_interceptor(tools, ["tool_b"]) + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + # tool_a should not interrupt + mock_interrupt.return_value = "approved" + result_a = wrapped_tools[0].invoke("test") + mock_interrupt.assert_not_called() + + # tool_b should interrupt + result_b = wrapped_tools[1].invoke("test") + mock_interrupt.assert_called() + + # tool_c should not interrupt + mock_interrupt.reset_mock() + result_c = wrapped_tools[2].invoke("test") + mock_interrupt.assert_not_called() + + def test_interrupt_with_user_approval(self): + """Test interrupt flow with user approval.""" + @tool + def sensitive_tool(action: str) -> str: + """A sensitive tool.""" + return f"Executed: {action}" + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = "approved" + + interceptor = ToolInterceptor(["sensitive_tool"]) + wrapped = ToolInterceptor.wrap_tool(sensitive_tool, interceptor) + + result = wrapped.invoke("delete_data") + + mock_interrupt.assert_called() + assert "Executed: delete_data" in str(result) + + def test_interrupt_with_user_rejection(self): + """Test interrupt flow with user rejection.""" + @tool + def sensitive_tool(action: str) -> str: + """A sensitive tool.""" + return f"Executed: {action}" + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = "rejected" + + interceptor = ToolInterceptor(["sensitive_tool"]) + wrapped = ToolInterceptor.wrap_tool(sensitive_tool, interceptor) + + result = wrapped.invoke("delete_data") + + mock_interrupt.assert_called() + assert isinstance(result, dict) + assert "error" in result + assert result["status"] == "rejected" + + def test_interrupt_message_contains_tool_info(self): + """Test that interrupt message contains tool name and input.""" + @tool + def db_query_tool(query: str) -> str: + """Database query tool.""" + return f"Query result: {query}" + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = "approved" + + interceptor = ToolInterceptor(["db_query_tool"]) + wrapped = ToolInterceptor.wrap_tool(db_query_tool, interceptor) + + wrapped.invoke("SELECT * FROM users") + + # Verify interrupt was called with meaningful message + mock_interrupt.assert_called() + interrupt_message = mock_interrupt.call_args[0][0] + assert "db_query_tool" in interrupt_message + assert "SELECT * FROM users" in interrupt_message + + def test_tool_wrapping_preserves_functionality(self): + """Test that tool wrapping preserves original tool functionality.""" + @tool + def simple_tool(text: str) -> str: + """Process text.""" + return f"Processed: {text}" + + interceptor = ToolInterceptor([]) # No interrupts + wrapped = ToolInterceptor.wrap_tool(simple_tool, interceptor) + + result = wrapped.invoke({"text": "hello"}) + assert "hello" in str(result) + + def test_tool_wrapping_preserves_tool_metadata(self): + """Test that tool wrapping preserves tool name and description.""" + @tool + def my_special_tool(x: str) -> str: + """This is my special tool description.""" + return f"Result: {x}" + + interceptor = ToolInterceptor([]) + wrapped = ToolInterceptor.wrap_tool(my_special_tool, interceptor) + + assert wrapped.name == "my_special_tool" + assert "special tool" in wrapped.description.lower() + + def test_multiple_interrupts_in_sequence(self): + """Test handling multiple tool interrupts in sequence.""" + @tool + def tool_one(x: str) -> str: + """Tool one.""" + return f"One: {x}" + + @tool + def tool_two(x: str) -> str: + """Tool two.""" + return f"Two: {x}" + + @tool + def tool_three(x: str) -> str: + """Tool three.""" + return f"Three: {x}" + + tools = [tool_one, tool_two, tool_three] + wrapped_tools = wrap_tools_with_interceptor( + tools, ["tool_one", "tool_two"] + ) + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = "approved" + + # First interrupt + result_one = wrapped_tools[0].invoke("first") + assert mock_interrupt.call_count == 1 + + # Second interrupt + result_two = wrapped_tools[1].invoke("second") + assert mock_interrupt.call_count == 2 + + # Third (no interrupt) + result_three = wrapped_tools[2].invoke("third") + assert mock_interrupt.call_count == 2 + + assert "One: first" in str(result_one) + assert "Two: second" in str(result_two) + assert "Three: third" in str(result_three) + + def test_empty_interrupt_list_no_interrupts(self): + """Test that empty interrupt list doesn't trigger interrupts.""" + @tool + def test_tool(x: str) -> str: + """Test tool.""" + return f"Result: {x}" + + wrapped_tools = wrap_tools_with_interceptor([test_tool], []) + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + wrapped_tools[0].invoke("test") + mock_interrupt.assert_not_called() + + def test_none_interrupt_list_no_interrupts(self): + """Test that None interrupt list doesn't trigger interrupts.""" + @tool + def test_tool(x: str) -> str: + """Test tool.""" + return f"Result: {x}" + + wrapped_tools = wrap_tools_with_interceptor([test_tool], None) + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + wrapped_tools[0].invoke("test") + mock_interrupt.assert_not_called() + + def test_case_sensitive_tool_name_matching(self): + """Test that tool name matching is case-sensitive.""" + @tool + def MyTool(x: str) -> str: + """A tool.""" + return f"Result: {x}" + + interceptor_lower = ToolInterceptor(["mytool"]) + interceptor_exact = ToolInterceptor(["MyTool"]) + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = "approved" + + # Case mismatch - should NOT interrupt + wrapped_lower = ToolInterceptor.wrap_tool(MyTool, interceptor_lower) + result_lower = wrapped_lower.invoke("test") + mock_interrupt.assert_not_called() + + # Case match - should interrupt + wrapped_exact = ToolInterceptor.wrap_tool(MyTool, interceptor_exact) + result_exact = wrapped_exact.invoke("test") + mock_interrupt.assert_called() + + def test_tool_error_handling(self): + """Test handling of tool errors during execution.""" + @tool + def error_tool(x: str) -> str: + """A tool that raises an error.""" + raise ValueError(f"Intentional error: {x}") + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = "approved" + + interceptor = ToolInterceptor(["error_tool"]) + wrapped = ToolInterceptor.wrap_tool(error_tool, interceptor) + + with pytest.raises(ValueError) as exc_info: + wrapped.invoke("test") + + assert "Intentional error: test" in str(exc_info.value) + + def test_approval_keywords_comprehensive(self): + """Test all approved keywords are recognized.""" + approval_keywords = [ + "approved", + "approve", + "yes", + "proceed", + "continue", + "ok", + "okay", + "accepted", + "accept", + "[approved]", + "APPROVED", + "Proceed with this action", + "[ACCEPTED] I approve", + ] + + for keyword in approval_keywords: + result = ToolInterceptor._parse_approval(keyword) + assert ( + result is True + ), f"Keyword '{keyword}' should be approved but got {result}" + + def test_rejection_keywords_comprehensive(self): + """Test that rejection keywords are recognized.""" + rejection_keywords = [ + "no", + "reject", + "cancel", + "decline", + "stop", + "abort", + "maybe", + "later", + "random text", + "", + ] + + for keyword in rejection_keywords: + result = ToolInterceptor._parse_approval(keyword) + assert ( + result is False + ), f"Keyword '{keyword}' should be rejected but got {result}" + + def test_interrupt_with_complex_tool_input(self): + """Test interrupt with complex tool input types.""" + @tool + def complex_tool(data: str) -> str: + """A tool with complex input.""" + return f"Processed: {data}" + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = "approved" + + interceptor = ToolInterceptor(["complex_tool"]) + wrapped = ToolInterceptor.wrap_tool(complex_tool, interceptor) + + complex_input = { + "data": "complex data with nested info" + } + + result = wrapped.invoke(complex_input) + + mock_interrupt.assert_called() + assert "Processed" in str(result) + + def test_configuration_from_runnable_config(self): + """Test Configuration.from_runnable_config with interrupt_before_tools.""" + from langchain_core.runnables import RunnableConfig + + config = RunnableConfig( + configurable={ + "interrupt_before_tools": ["db_tool"], + "max_step_num": 5, + } + ) + + configuration = Configuration.from_runnable_config(config) + + assert configuration.interrupt_before_tools == ["db_tool"] + assert configuration.max_step_num == 5 + + def test_tool_interceptor_initialization_logging(self): + """Test that ToolInterceptor initialization is logged.""" + with patch("src.agents.tool_interceptor.logger") as mock_logger: + interceptor = ToolInterceptor(["tool_a", "tool_b"]) + mock_logger.info.assert_called() + + def test_wrap_tools_with_interceptor_logging(self): + """Test that tool wrapping is logged.""" + @tool + def test_tool(x: str) -> str: + """Test.""" + return x + + with patch("src.agents.tool_interceptor.logger") as mock_logger: + wrapped = wrap_tools_with_interceptor([test_tool], ["test_tool"]) + # Check that at least one info log was called + assert mock_logger.info.called or mock_logger.debug.called + + def test_interrupt_resolution_with_empty_feedback(self): + """Test interrupt resolution with empty feedback.""" + @tool + def test_tool(x: str) -> str: + """Test.""" + return f"Result: {x}" + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = "" + + interceptor = ToolInterceptor(["test_tool"]) + wrapped = ToolInterceptor.wrap_tool(test_tool, interceptor) + + result = wrapped.invoke("test") + + # Empty feedback should be treated as rejection + assert isinstance(result, dict) + assert result["status"] == "rejected" + + def test_interrupt_resolution_with_none_feedback(self): + """Test interrupt resolution with None feedback.""" + @tool + def test_tool(x: str) -> str: + """Test.""" + return f"Result: {x}" + + with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + mock_interrupt.return_value = None + + interceptor = ToolInterceptor(["test_tool"]) + wrapped = ToolInterceptor.wrap_tool(test_tool, interceptor) + + result = wrapped.invoke("test") + + # None feedback should be treated as rejection + assert isinstance(result, dict) + assert result["status"] == "rejected" diff --git a/tests/unit/agents/test_tool_interceptor.py b/tests/unit/agents/test_tool_interceptor.py new file mode 100644 index 0000000..8fa37df --- /dev/null +++ b/tests/unit/agents/test_tool_interceptor.py @@ -0,0 +1,433 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import pytest +from unittest.mock import Mock, patch, MagicMock, AsyncMock +from langchain_core.tools import tool, BaseTool + +from src.agents.tool_interceptor import ( + ToolInterceptor, + wrap_tools_with_interceptor, +) + + +class TestToolInterceptor: + """Tests for ToolInterceptor class.""" + + def test_init_with_tools(self): + """Test initializing interceptor with tool list.""" + tools = ["db_tool", "api_tool"] + interceptor = ToolInterceptor(tools) + assert interceptor.interrupt_before_tools == tools + + def test_init_without_tools(self): + """Test initializing interceptor without tools.""" + interceptor = ToolInterceptor() + assert interceptor.interrupt_before_tools == [] + + def test_should_interrupt_with_matching_tool(self): + """Test should_interrupt returns True for matching tools.""" + tools = ["db_tool", "api_tool"] + interceptor = ToolInterceptor(tools) + assert interceptor.should_interrupt("db_tool") is True + assert interceptor.should_interrupt("api_tool") is True + + def test_should_interrupt_with_non_matching_tool(self): + """Test should_interrupt returns False for non-matching tools.""" + tools = ["db_tool", "api_tool"] + interceptor = ToolInterceptor(tools) + assert interceptor.should_interrupt("search_tool") is False + assert interceptor.should_interrupt("crawl_tool") is False + + def test_should_interrupt_empty_list(self): + """Test should_interrupt with empty interrupt list.""" + interceptor = ToolInterceptor([]) + assert interceptor.should_interrupt("db_tool") is False + + def test_parse_approval_with_approval_keywords(self): + """Test parsing user feedback with approval keywords.""" + assert ToolInterceptor._parse_approval("approved") is True + assert ToolInterceptor._parse_approval("approve") is True + assert ToolInterceptor._parse_approval("yes") is True + assert ToolInterceptor._parse_approval("proceed") is True + assert ToolInterceptor._parse_approval("continue") is True + assert ToolInterceptor._parse_approval("ok") is True + assert ToolInterceptor._parse_approval("okay") is True + assert ToolInterceptor._parse_approval("accepted") is True + assert ToolInterceptor._parse_approval("accept") is True + assert ToolInterceptor._parse_approval("[approved]") is True + + def test_parse_approval_case_insensitive(self): + """Test parsing is case-insensitive.""" + assert ToolInterceptor._parse_approval("APPROVED") is True + assert ToolInterceptor._parse_approval("Approved") is True + assert ToolInterceptor._parse_approval("PROCEED") is True + + def test_parse_approval_with_surrounding_text(self): + """Test parsing with surrounding text.""" + assert ToolInterceptor._parse_approval("Sure, proceed with the tool") is True + assert ToolInterceptor._parse_approval("[ACCEPTED] I approve this") is True + + def test_parse_approval_rejection(self): + """Test parsing rejects non-approval feedback.""" + assert ToolInterceptor._parse_approval("no") is False + assert ToolInterceptor._parse_approval("reject") is False + assert ToolInterceptor._parse_approval("cancel") is False + assert ToolInterceptor._parse_approval("random feedback") is False + + def test_parse_approval_empty_string(self): + """Test parsing empty string.""" + assert ToolInterceptor._parse_approval("") is False + + def test_parse_approval_none(self): + """Test parsing None.""" + assert ToolInterceptor._parse_approval(None) is False + + @patch("src.agents.tool_interceptor.interrupt") + def test_wrap_tool_with_interrupt(self, mock_interrupt): + """Test wrapping a tool with interrupt.""" + mock_interrupt.return_value = "approved" + + # Create a simple test tool + @tool + def test_tool(input_text: str) -> str: + """Test tool.""" + return f"Result: {input_text}" + + interceptor = ToolInterceptor(["test_tool"]) + + # Wrap the tool + wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor) + + # Invoke the wrapped tool + result = wrapped_tool.invoke("hello") + + # Verify interrupt was called + mock_interrupt.assert_called_once() + assert "test_tool" in mock_interrupt.call_args[0][0] + + @patch("src.agents.tool_interceptor.interrupt") + def test_wrap_tool_without_interrupt(self, mock_interrupt): + """Test wrapping a tool that doesn't trigger interrupt.""" + # Create a simple test tool + @tool + def test_tool(input_text: str) -> str: + """Test tool.""" + return f"Result: {input_text}" + + interceptor = ToolInterceptor(["other_tool"]) + + # Wrap the tool + wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor) + + # Invoke the wrapped tool + result = wrapped_tool.invoke("hello") + + # Verify interrupt was NOT called + mock_interrupt.assert_not_called() + assert "Result: hello" in str(result) + + @patch("src.agents.tool_interceptor.interrupt") + def test_wrap_tool_user_rejects(self, mock_interrupt): + """Test user rejecting tool execution.""" + mock_interrupt.return_value = "no" + + @tool + def test_tool(input_text: str) -> str: + """Test tool.""" + return f"Result: {input_text}" + + interceptor = ToolInterceptor(["test_tool"]) + wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor) + + # Invoke the wrapped tool + result = wrapped_tool.invoke("hello") + + # Verify tool was not executed + assert isinstance(result, dict) + assert "error" in result + assert result["status"] == "rejected" + + def test_wrap_tools_with_interceptor_empty_list(self): + """Test wrapping tools with empty interrupt list.""" + @tool + def test_tool(input_text: str) -> str: + """Test tool.""" + return f"Result: {input_text}" + + tools = [test_tool] + wrapped_tools = wrap_tools_with_interceptor(tools, []) + + # Should return tools as-is + assert len(wrapped_tools) == 1 + assert wrapped_tools[0].name == "test_tool" + + def test_wrap_tools_with_interceptor_none(self): + """Test wrapping tools with None interrupt list.""" + @tool + def test_tool(input_text: str) -> str: + """Test tool.""" + return f"Result: {input_text}" + + tools = [test_tool] + wrapped_tools = wrap_tools_with_interceptor(tools, None) + + # Should return tools as-is + assert len(wrapped_tools) == 1 + + @patch("src.agents.tool_interceptor.interrupt") + def test_wrap_tools_with_interceptor_multiple(self, mock_interrupt): + """Test wrapping multiple tools.""" + mock_interrupt.return_value = "approved" + + @tool + def db_tool(query: str) -> str: + """DB tool.""" + return f"Query result: {query}" + + @tool + def search_tool(query: str) -> str: + """Search tool.""" + return f"Search result: {query}" + + tools = [db_tool, search_tool] + wrapped_tools = wrap_tools_with_interceptor(tools, ["db_tool"]) + + # Only db_tool should trigger interrupt + db_result = wrapped_tools[0].invoke("test query") + assert mock_interrupt.call_count == 1 + + search_result = wrapped_tools[1].invoke("test query") + # No additional interrupt calls for search_tool + assert mock_interrupt.call_count == 1 + + def test_wrap_tool_preserves_tool_properties(self): + """Test that wrapping preserves tool properties.""" + @tool + def my_tool(input_text: str) -> str: + """My tool description.""" + return f"Result: {input_text}" + + interceptor = ToolInterceptor([]) + wrapped_tool = ToolInterceptor.wrap_tool(my_tool, interceptor) + + assert wrapped_tool.name == "my_tool" + assert wrapped_tool.description == "My tool description." + + +class TestFormatToolInput: + """Tests for tool input formatting functionality.""" + + def test_format_tool_input_none(self): + """Test formatting None input.""" + result = ToolInterceptor._format_tool_input(None) + assert result == "No input" + + def test_format_tool_input_string(self): + """Test formatting string input.""" + input_str = "SELECT * FROM users" + result = ToolInterceptor._format_tool_input(input_str) + assert result == input_str + + def test_format_tool_input_simple_dict(self): + """Test formatting simple dictionary.""" + input_dict = {"query": "test", "limit": 10} + result = ToolInterceptor._format_tool_input(input_dict) + + # Should be valid JSON + import json + parsed = json.loads(result) + assert parsed == input_dict + # Should be indented + assert "\n" in result + + def test_format_tool_input_nested_dict(self): + """Test formatting nested dictionary.""" + input_dict = { + "query": "SELECT * FROM users", + "config": { + "timeout": 30, + "retry": True + } + } + result = ToolInterceptor._format_tool_input(input_dict) + + import json + parsed = json.loads(result) + assert parsed == input_dict + assert "timeout" in result + assert "retry" in result + + def test_format_tool_input_list(self): + """Test formatting list input.""" + input_list = ["item1", "item2", 123] + result = ToolInterceptor._format_tool_input(input_list) + + import json + parsed = json.loads(result) + assert parsed == input_list + + def test_format_tool_input_complex_list(self): + """Test formatting list with mixed types.""" + input_list = ["text", 42, 3.14, True, {"key": "value"}] + result = ToolInterceptor._format_tool_input(input_list) + + import json + parsed = json.loads(result) + assert parsed == input_list + + def test_format_tool_input_tuple(self): + """Test formatting tuple input.""" + input_tuple = ("item1", "item2", 123) + result = ToolInterceptor._format_tool_input(input_tuple) + + import json + parsed = json.loads(result) + # JSON converts tuples to lists + assert parsed == list(input_tuple) + + def test_format_tool_input_integer(self): + """Test formatting integer input.""" + result = ToolInterceptor._format_tool_input(42) + assert result == "42" + + def test_format_tool_input_float(self): + """Test formatting float input.""" + result = ToolInterceptor._format_tool_input(3.14) + assert result == "3.14" + + def test_format_tool_input_boolean(self): + """Test formatting boolean input.""" + result_true = ToolInterceptor._format_tool_input(True) + result_false = ToolInterceptor._format_tool_input(False) + assert result_true == "True" + assert result_false == "False" + + def test_format_tool_input_deeply_nested(self): + """Test formatting deeply nested structure.""" + input_dict = { + "level1": { + "level2": { + "level3": { + "level4": ["a", "b", "c"], + "data": {"key": "value"} + } + } + } + } + result = ToolInterceptor._format_tool_input(input_dict) + + import json + parsed = json.loads(result) + assert parsed == input_dict + + def test_format_tool_input_empty_dict(self): + """Test formatting empty dictionary.""" + result = ToolInterceptor._format_tool_input({}) + assert result == "{}" + + def test_format_tool_input_empty_list(self): + """Test formatting empty list.""" + result = ToolInterceptor._format_tool_input([]) + assert result == "[]" + + def test_format_tool_input_special_characters(self): + """Test formatting dict with special characters.""" + input_dict = { + "query": 'SELECT * FROM users WHERE name = "John"', + "path": "/usr/local/bin", + "unicode": "你好世界" + } + result = ToolInterceptor._format_tool_input(input_dict) + + import json + parsed = json.loads(result) + assert parsed == input_dict + + def test_format_tool_input_numbers_as_strings(self): + """Test formatting with various number types.""" + input_dict = { + "int": 42, + "float": 3.14159, + "negative": -100, + "zero": 0, + "scientific": 1e-5 + } + result = ToolInterceptor._format_tool_input(input_dict) + + import json + parsed = json.loads(result) + assert parsed["int"] == 42 + assert abs(parsed["float"] - 3.14159) < 0.00001 + assert parsed["negative"] == -100 + assert parsed["zero"] == 0 + + def test_format_tool_input_with_none_values(self): + """Test formatting dict with None values.""" + input_dict = { + "key1": "value1", + "key2": None, + "key3": {"nested": None} + } + result = ToolInterceptor._format_tool_input(input_dict) + + import json + parsed = json.loads(result) + assert parsed == input_dict + + def test_format_tool_input_indentation(self): + """Test that output uses proper indentation (2 spaces).""" + input_dict = {"outer": {"inner": "value"}} + result = ToolInterceptor._format_tool_input(input_dict) + + # Should have indented lines + assert " " in result # 2-space indentation + lines = result.split("\n") + # Check that indentation increases with nesting + assert any(line.startswith(" ") for line in lines) + + def test_format_tool_input_preserves_order_insertion(self): + """Test that dict order is preserved in output.""" + input_dict = { + "first": 1, + "second": 2, + "third": 3 + } + result = ToolInterceptor._format_tool_input(input_dict) + + import json + parsed = json.loads(result) + # Verify all keys are present + assert set(parsed.keys()) == {"first", "second", "third"} + + def test_format_tool_input_long_strings(self): + """Test formatting with long string values.""" + long_string = "x" * 1000 + input_dict = {"long": long_string} + result = ToolInterceptor._format_tool_input(input_dict) + + import json + parsed = json.loads(result) + assert parsed["long"] == long_string + + def test_format_tool_input_mixed_types_in_list(self): + """Test formatting list with mixed complex types.""" + input_list = [ + "string", + 42, + {"dict": "value"}, + [1, 2, 3], + True, + None + ] + result = ToolInterceptor._format_tool_input(input_list) + + import json + parsed = json.loads(result) + assert len(parsed) == 6 + assert parsed[0] == "string" + assert parsed[1] == 42 + assert parsed[2] == {"dict": "value"} + assert parsed[3] == [1, 2, 3] + assert parsed[4] is True + assert parsed[5] is None