diff --git a/src/agents/tool_interceptor.py b/src/agents/tool_interceptor.py index 5b7e459..84b47a2 100644 --- a/src/agents/tool_interceptor.py +++ b/src/agents/tool_interceptor.py @@ -159,6 +159,13 @@ class ToolInterceptor: # Use object.__setattr__ to bypass Pydantic validation logger.debug(f"Attaching intercepted function to tool '{safe_tool_name}'") object.__setattr__(tool, "func", intercepted_func) + + # Also ensure the tool's _run method is updated if it exists + if hasattr(tool, '_run'): + logger.debug(f"Also wrapping _run method for tool '{safe_tool_name}'") + # Wrap _run to ensure interception is applied regardless of invocation method + object.__setattr__(tool, "_run", intercepted_func) + return tool @staticmethod diff --git a/tests/unit/agents/test_tool_interceptor_fix.py b/tests/unit/agents/test_tool_interceptor_fix.py new file mode 100644 index 0000000..a87ff4a --- /dev/null +++ b/tests/unit/agents/test_tool_interceptor_fix.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import unittest +from unittest.mock import MagicMock +from langchain_core.tools import Tool +from src.agents.tool_interceptor import ToolInterceptor + +class TestToolInterceptorFix(unittest.TestCase): + def test_interceptor_patches_run_method(self): + # Create a mock tool + mock_func = MagicMock(return_value="Original Result") + tool = Tool(name="resolve_company_name", func=mock_func, description="test tool") + + # Interceptor that always interrupts 'resolve_company_name' + interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"]) + + # Wrap the tool + wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor) + + # Mock interrupt to avoid actual suspension + with unittest.mock.patch("src.agents.tool_interceptor.interrupt", return_value="approved"): + # Call using .run() which triggers ._run() + # Standard BaseTool execution flow is invoke -> run -> _run + # If we only patched func, run() would call original _run which calls original func, bypassing interception + # With the fix, _run should be patched to call intercepted_func + result = wrapped_tool.run("some input") + + # Verify result + self.assertEqual(result, "Original Result") + + # Verify the original function was called + # If interception works, intercepted_func calls original_func + mock_func.assert_called_once() + + def test_run_method_without_interrupt(self): + """Test that tools not in interrupt list work normally via .run()""" + mock_func = MagicMock(return_value="Result") + tool = Tool(name="other_tool", func=mock_func, description="test") + + interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"]) + wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor) + + with unittest.mock.patch("src.agents.tool_interceptor.interrupt") as mock_interrupt: + result = wrapped_tool.run("input") + + # Verify interrupt was NOT called for non-intercepted tool + mock_interrupt.assert_not_called() + assert result == "Result" + mock_func.assert_called_once() + + def test_interceptor_resolve_company_name_example(self): + """Test specific resolve_company_name logic capability using interceptor subclassing or custom logic simulation.""" + # This test verifies that we can intercept execution of resolve_company_name + # even if it's called via .run() + + mock_func = MagicMock(return_value='{"code": 0, "data": [{"companyName": "A"}, {"companyName": "B"}]}') + tool = Tool(name="resolve_company_name", func=mock_func, description="resolve company") + + interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"]) + wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor) + + # Simulate user selecting "B" + with unittest.mock.patch("src.agents.tool_interceptor.interrupt", return_value="approved"): + # We are not testing the complex business logic here because we didn't add it to ToolInterceptor class + # We are mostly verifying that the INTERCEPTION mechanism works for this tool name when called via .run() + wrapped_tool.run("query") + + mock_func.assert_called_once()