mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 22:32:12 +08:00
* security: add log injection attack prevention with input sanitization

- Created src/utils/log_sanitizer.py to sanitize user-controlled input before logging
- Prevents log injection attacks using newlines, tabs, carriage returns, etc.
- Escapes dangerous characters: \n, \r, \t, \0, \x1b
- Provides specialized functions for different input types:
  - sanitize_log_input: general purpose sanitization
  - sanitize_thread_id: for user-provided thread IDs
  - sanitize_user_content: for user messages (more aggressive truncation)
  - sanitize_agent_name: for agent identifiers
  - sanitize_tool_name: for tool names
  - sanitize_feedback: for user interrupt feedback
  - create_safe_log_message: template-based safe message creation
- Updated src/server/app.py to sanitize all user input in logging:
  - Thread IDs from request parameter
  - Message content from user
  - Agent names and node information
  - Tool names and feedback
- Updated src/agents/tool_interceptor.py to sanitize:
  - Tool names during execution
  - User feedback during interrupt handling
  - Tool input data
- Added 29 comprehensive unit tests covering:
  - Classic newline injection attacks
  - Carriage return injection
  - Tab and null character injection
  - HTML/ANSI escape sequence injection
  - Combined multi-character attacks
  - Truncation and length limits

Fixes potential log forgery vulnerability where malicious users could inject fake log entries via unsanitized input containing control characters.
435 lines
15 KiB
Python
435 lines
15 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from langchain_core.tools import BaseTool, tool

from src.agents.tool_interceptor import (
    ToolInterceptor,
    wrap_tools_with_interceptor,
)
|
|
|
|
|
|
class TestToolInterceptor:
    """Tests for ToolInterceptor and wrap_tools_with_interceptor.

    Covers construction, interrupt matching, approval parsing, and the
    behaviour of wrapped tools when the (mocked) ``interrupt`` call
    approves or rejects execution.
    """

    def test_init_with_tools(self):
        """The interceptor exposes the tool names it was constructed with."""
        tools = ["db_tool", "api_tool"]
        interceptor = ToolInterceptor(tools)
        assert interceptor.interrupt_before_tools == tools

    def test_init_without_tools(self):
        """Constructing without arguments yields an empty interrupt list."""
        interceptor = ToolInterceptor()
        assert interceptor.interrupt_before_tools == []

    def test_should_interrupt_with_matching_tool(self):
        """should_interrupt returns True for every configured tool name."""
        interceptor = ToolInterceptor(["db_tool", "api_tool"])
        for tool_name in ("db_tool", "api_tool"):
            assert interceptor.should_interrupt(tool_name) is True, tool_name

    def test_should_interrupt_with_non_matching_tool(self):
        """should_interrupt returns False for names not in the list."""
        interceptor = ToolInterceptor(["db_tool", "api_tool"])
        for tool_name in ("search_tool", "crawl_tool"):
            assert interceptor.should_interrupt(tool_name) is False, tool_name

    def test_should_interrupt_empty_list(self):
        """An empty interrupt list never triggers an interrupt."""
        interceptor = ToolInterceptor([])
        assert interceptor.should_interrupt("db_tool") is False

    def test_parse_approval_with_approval_keywords(self):
        """Each recognised approval keyword parses as approval."""
        approval_keywords = (
            "approved",
            "approve",
            "yes",
            "proceed",
            "continue",
            "ok",
            "okay",
            "accepted",
            "accept",
            "[approved]",
        )
        for feedback in approval_keywords:
            # The message names the failing keyword when the loop aborts.
            assert ToolInterceptor._parse_approval(feedback) is True, feedback

    def test_parse_approval_case_insensitive(self):
        """Approval keywords are matched regardless of letter case."""
        for feedback in ("APPROVED", "Approved", "PROCEED"):
            assert ToolInterceptor._parse_approval(feedback) is True, feedback

    def test_parse_approval_with_surrounding_text(self):
        """A keyword embedded in a longer sentence still counts as approval."""
        for feedback in (
            "Sure, proceed with the tool",
            "[ACCEPTED] I approve this",
        ):
            assert ToolInterceptor._parse_approval(feedback) is True, feedback

    def test_parse_approval_rejection(self):
        """Feedback without an approval keyword is treated as rejection."""
        for feedback in ("no", "reject", "cancel", "random feedback"):
            assert ToolInterceptor._parse_approval(feedback) is False, feedback

    def test_parse_approval_empty_string(self):
        """An empty feedback string is not an approval."""
        assert ToolInterceptor._parse_approval("") is False

    def test_parse_approval_none(self):
        """None feedback is not an approval."""
        assert ToolInterceptor._parse_approval(None) is False

    @patch("src.agents.tool_interceptor.interrupt")
    def test_wrap_tool_with_interrupt(self, mock_interrupt):
        """A listed tool triggers exactly one interrupt, then runs if approved."""
        mock_interrupt.return_value = "approved"

        # The function name doubles as the tool name, so it must stay
        # "test_tool" to match the interceptor's interrupt list.
        @tool
        def test_tool(input_text: str) -> str:
            """Test tool."""
            return f"Result: {input_text}"

        interceptor = ToolInterceptor(["test_tool"])
        wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor)

        result = wrapped_tool.invoke("hello")

        # The interrupt prompt must mention the tool being gated.
        mock_interrupt.assert_called_once()
        assert "test_tool" in mock_interrupt.call_args[0][0]
        # An approved call is expected to fall through to the wrapped
        # function, mirroring the un-intercepted path checked in
        # test_wrap_tool_without_interrupt.
        assert "Result: hello" in str(result)

    @patch("src.agents.tool_interceptor.interrupt")
    def test_wrap_tool_without_interrupt(self, mock_interrupt):
        """A tool outside the interrupt list runs without any interrupt."""

        @tool
        def test_tool(input_text: str) -> str:
            """Test tool."""
            return f"Result: {input_text}"

        # Only "other_tool" is gated, so "test_tool" must pass straight through.
        interceptor = ToolInterceptor(["other_tool"])
        wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor)

        result = wrapped_tool.invoke("hello")

        mock_interrupt.assert_not_called()
        assert "Result: hello" in str(result)

    @patch("src.agents.tool_interceptor.interrupt")
    def test_wrap_tool_user_rejects(self, mock_interrupt):
        """Rejecting the interrupt yields a 'rejected' error payload."""
        mock_interrupt.return_value = "no"

        @tool
        def test_tool(input_text: str) -> str:
            """Test tool."""
            return f"Result: {input_text}"

        interceptor = ToolInterceptor(["test_tool"])
        wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor)

        result = wrapped_tool.invoke("hello")

        # The tool body must not have run: a structured rejection is
        # returned instead of the "Result: ..." string.
        assert isinstance(result, dict)
        assert "error" in result
        assert result["status"] == "rejected"

    def test_wrap_tools_with_interceptor_empty_list(self):
        """An empty interrupt list leaves the tool collection intact."""

        @tool
        def test_tool(input_text: str) -> str:
            """Test tool."""
            return f"Result: {input_text}"

        wrapped_tools = wrap_tools_with_interceptor([test_tool], [])

        assert len(wrapped_tools) == 1
        assert wrapped_tools[0].name == "test_tool"

    def test_wrap_tools_with_interceptor_none(self):
        """A None interrupt list leaves the tool collection intact."""

        @tool
        def test_tool(input_text: str) -> str:
            """Test tool."""
            return f"Result: {input_text}"

        wrapped_tools = wrap_tools_with_interceptor([test_tool], None)

        assert len(wrapped_tools) == 1
        # Same identity check as the empty-list case, for parity.
        assert wrapped_tools[0].name == "test_tool"

    @patch("src.agents.tool_interceptor.interrupt")
    def test_wrap_tools_with_interceptor_multiple(self, mock_interrupt):
        """Only the tools named in the interrupt list trigger an interrupt."""
        mock_interrupt.return_value = "approved"

        @tool
        def db_tool(query: str) -> str:
            """DB tool."""
            return f"Query result: {query}"

        @tool
        def search_tool(query: str) -> str:
            """Search tool."""
            return f"Search result: {query}"

        wrapped_tools = wrap_tools_with_interceptor(
            [db_tool, search_tool], ["db_tool"]
        )

        # db_tool is gated: one interrupt, then the approved call runs.
        db_result = wrapped_tools[0].invoke("test query")
        assert mock_interrupt.call_count == 1
        assert "Query result: test query" in str(db_result)

        # search_tool is not gated: the call count must stay unchanged.
        search_result = wrapped_tools[1].invoke("test query")
        assert mock_interrupt.call_count == 1
        assert "Search result: test query" in str(search_result)

    def test_wrap_tool_preserves_tool_properties(self):
        """Wrapping keeps the tool's name and description."""

        @tool
        def my_tool(input_text: str) -> str:
            """My tool description."""
            return f"Result: {input_text}"

        interceptor = ToolInterceptor([])
        wrapped_tool = ToolInterceptor.wrap_tool(my_tool, interceptor)

        assert wrapped_tool.name == "my_tool"
        assert wrapped_tool.description == "My tool description."
|
|
|
|
|
|
class TestFormatToolInput:
    """Tests for ToolInterceptor._format_tool_input.

    Scalars are expected to be rendered with their plain string form,
    while dicts, lists and tuples are expected to be rendered as
    indented JSON that round-trips through ``json.loads``.
    """

    def test_format_tool_input_none(self):
        """None is rendered as the placeholder text 'No input'."""
        result = ToolInterceptor._format_tool_input(None)
        assert result == "No input"

    def test_format_tool_input_string(self):
        """A string input is returned unchanged."""
        input_str = "SELECT * FROM users"
        result = ToolInterceptor._format_tool_input(input_str)
        assert result == input_str

    def test_format_tool_input_simple_dict(self):
        """A flat dict is rendered as multi-line JSON."""
        input_dict = {"query": "test", "limit": 10}
        result = ToolInterceptor._format_tool_input(input_dict)

        # Must parse back to the same data ...
        assert json.loads(result) == input_dict
        # ... and be pretty-printed across several lines.
        assert "\n" in result

    def test_format_tool_input_nested_dict(self):
        """A nested dict round-trips and keeps its inner keys visible."""
        input_dict = {
            "query": "SELECT * FROM users",
            "config": {
                "timeout": 30,
                "retry": True,
            },
        }
        result = ToolInterceptor._format_tool_input(input_dict)

        assert json.loads(result) == input_dict
        assert "timeout" in result
        assert "retry" in result

    def test_format_tool_input_list(self):
        """A list is rendered as a JSON array."""
        input_list = ["item1", "item2", 123]
        result = ToolInterceptor._format_tool_input(input_list)

        assert json.loads(result) == input_list

    def test_format_tool_input_complex_list(self):
        """A list mixing scalars and a dict round-trips through JSON."""
        input_list = ["text", 42, 3.14, True, {"key": "value"}]
        result = ToolInterceptor._format_tool_input(input_list)

        assert json.loads(result) == input_list

    def test_format_tool_input_tuple(self):
        """A tuple is rendered as a JSON array."""
        input_tuple = ("item1", "item2", 123)
        result = ToolInterceptor._format_tool_input(input_tuple)

        # JSON has no tuple type, so the parsed value comes back as a list.
        assert json.loads(result) == list(input_tuple)

    def test_format_tool_input_integer(self):
        """An integer is rendered as its decimal string."""
        result = ToolInterceptor._format_tool_input(42)
        assert result == "42"

    def test_format_tool_input_float(self):
        """A float is rendered as its string form."""
        result = ToolInterceptor._format_tool_input(3.14)
        assert result == "3.14"

    def test_format_tool_input_boolean(self):
        """Booleans are rendered as 'True' / 'False', not JSON literals."""
        result_true = ToolInterceptor._format_tool_input(True)
        result_false = ToolInterceptor._format_tool_input(False)
        assert result_true == "True"
        assert result_false == "False"

    def test_format_tool_input_deeply_nested(self):
        """A four-level nested structure round-trips through JSON."""
        input_dict = {
            "level1": {
                "level2": {
                    "level3": {
                        "level4": ["a", "b", "c"],
                        "data": {"key": "value"},
                    }
                }
            }
        }
        result = ToolInterceptor._format_tool_input(input_dict)

        assert json.loads(result) == input_dict

    def test_format_tool_input_empty_dict(self):
        """An empty dict is rendered as '{}'."""
        result = ToolInterceptor._format_tool_input({})
        assert result == "{}"

    def test_format_tool_input_empty_list(self):
        """An empty list is rendered as '[]'."""
        result = ToolInterceptor._format_tool_input([])
        assert result == "[]"

    def test_format_tool_input_special_characters(self):
        """Quotes, slashes and non-ASCII text survive the round trip."""
        input_dict = {
            "query": 'SELECT * FROM users WHERE name = "John"',
            "path": "/usr/local/bin",
            "unicode": "你好世界",
        }
        result = ToolInterceptor._format_tool_input(input_dict)

        assert json.loads(result) == input_dict

    def test_format_tool_input_numbers_as_strings(self):
        """Integer, float, negative, zero and tiny values all round-trip."""
        input_dict = {
            "int": 42,
            "float": 3.14159,
            "negative": -100,
            "zero": 0,
            "scientific": 1e-5,
        }
        result = ToolInterceptor._format_tool_input(input_dict)

        parsed = json.loads(result)
        assert parsed["int"] == 42
        assert abs(parsed["float"] - 3.14159) < 0.00001
        assert parsed["negative"] == -100
        assert parsed["zero"] == 0
        # The small-magnitude value was previously built but never checked.
        assert abs(parsed["scientific"] - 1e-5) < 1e-12

    def test_format_tool_input_with_none_values(self):
        """None values, top-level and nested, round-trip as JSON null."""
        input_dict = {
            "key1": "value1",
            "key2": None,
            "key3": {"nested": None},
        }
        result = ToolInterceptor._format_tool_input(input_dict)

        assert json.loads(result) == input_dict

    def test_format_tool_input_indentation(self):
        """Output is indented by two spaces per nesting level."""
        input_dict = {"outer": {"inner": "value"}}
        result = ToolInterceptor._format_tool_input(input_dict)

        lines = result.split("\n")
        # A bare single-space check would be satisfied by JSON's ": "
        # separator alone, so inspect the leading whitespace of each line.
        indents = {len(line) - len(line.lstrip(" ")) for line in lines}
        # First nesting level: exactly two spaces.
        assert 2 in indents
        # Second nesting level: indentation increases to four spaces.
        assert 4 in indents

    def test_format_tool_input_preserves_order_insertion(self):
        """Dict keys appear in the output in insertion order."""
        input_dict = {
            "first": 1,
            "second": 2,
            "third": 3,
        }
        result = ToolInterceptor._format_tool_input(input_dict)

        parsed = json.loads(result)
        # json.loads builds the dict in document order, so comparing key
        # lists (not sets) actually verifies the ordering of the output.
        assert list(parsed) == list(input_dict)

    def test_format_tool_input_long_strings(self):
        """A 1000-character string value round-trips in full."""
        long_string = "x" * 1000
        input_dict = {"long": long_string}
        result = ToolInterceptor._format_tool_input(input_dict)

        parsed = json.loads(result)
        assert parsed["long"] == long_string

    def test_format_tool_input_mixed_types_in_list(self):
        """A list of heterogeneous elements keeps each element's type."""
        input_list = [
            "string",
            42,
            {"dict": "value"},
            [1, 2, 3],
            True,
            None,
        ]
        result = ToolInterceptor._format_tool_input(input_list)

        parsed = json.loads(result)
        assert len(parsed) == 6
        assert parsed[0] == "string"
        assert parsed[1] == 42
        assert parsed[2] == {"dict": "value"}
        assert parsed[3] == [1, 2, 3]
        # Identity checks guard against True/None degrading to 1/"None".
        assert parsed[4] is True
        assert parsed[5] is None
|