From b4c09aa4b1cc0f0edb8b20d876daf38877c6d36c Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Mon, 27 Oct 2025 20:57:23 +0800 Subject: [PATCH] security: add log injection attack prevention with input sanitization (#667) * security: add log injection attack prevention with input sanitization - Created src/utils/log_sanitizer.py to sanitize user-controlled input before logging - Prevents log injection attacks using newlines, tabs, carriage returns, etc. - Escapes dangerous characters: \n, \r, \t, \0, \x1b - Provides specialized functions for different input types: - sanitize_log_input: general purpose sanitization - sanitize_thread_id: for user-provided thread IDs - sanitize_user_content: for user messages (more aggressive truncation) - sanitize_agent_name: for agent identifiers - sanitize_tool_name: for tool names - sanitize_feedback: for user interrupt feedback - create_safe_log_message: template-based safe message creation - Updated src/server/app.py to sanitize all user input in logging: - Thread IDs from request parameter - Message content from user - Agent names and node information - Tool names and feedback - Updated src/agents/tool_interceptor.py to sanitize: - Tool names during execution - User feedback during interrupt handling - Tool input data - Added 29 comprehensive unit tests covering: - Classic newline injection attacks - Carriage return injection - Tab and null character injection - HTML/ANSI escape sequence injection - Combined multi-character attacks - Truncation and length limits Fixes potential log forgery vulnerability where malicious users could inject fake log entries via unsanitized input containing control characters. 
--- src/agents/agents.py | 2 +- src/agents/tool_interceptor.py | 43 +-- src/crawler/readability_extractor.py | 1 + src/server/app.py | 133 +++++---- src/utils/log_sanitizer.py | 186 ++++++++++++ .../test_tool_interceptor_integration.py | 6 +- tests/unit/agents/test_tool_interceptor.py | 5 +- tests/unit/crawler/test_jina_client.py | 4 +- .../crawler/test_readability_extractor.py | 1 + tests/unit/graph/test_plan_validation.py | 3 +- tests/unit/server/test_tool_call_chunks.py | 7 +- tests/unit/utils/test_json_utils.py | 6 +- tests/unit/utils/test_log_sanitizer.py | 268 ++++++++++++++++++ 13 files changed, 585 insertions(+), 80 deletions(-) create mode 100644 src/utils/log_sanitizer.py create mode 100644 tests/unit/utils/test_log_sanitizer.py diff --git a/src/agents/agents.py b/src/agents/agents.py index e1436e5..df94bd3 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -6,10 +6,10 @@ from typing import List, Optional from langgraph.prebuilt import create_react_agent +from src.agents.tool_interceptor import wrap_tools_with_interceptor from src.config.agents import AGENT_LLM_MAP from src.llms.llm import get_llm_by_type from src.prompts import apply_prompt_template -from src.agents.tool_interceptor import wrap_tools_with_interceptor logger = logging.getLogger(__name__) diff --git a/src/agents/tool_interceptor.py b/src/agents/tool_interceptor.py index b6d1d2b..5b7e459 100644 --- a/src/agents/tool_interceptor.py +++ b/src/agents/tool_interceptor.py @@ -8,6 +8,12 @@ from typing import Any, Callable, List, Optional from langchain_core.tools import BaseTool from langgraph.types import interrupt +from src.utils.log_sanitizer import ( + sanitize_feedback, + sanitize_log_input, + sanitize_tool_name, +) + logger = logging.getLogger(__name__) @@ -84,27 +90,30 @@ class ToolInterceptor: BaseTool: The wrapped tool with interrupt capability """ original_func = tool.func - logger.debug(f"Wrapping tool '{tool.name}' with interrupt capability") + safe_tool_name = 
sanitize_tool_name(tool.name) + logger.debug(f"Wrapping tool '{safe_tool_name}' with interrupt capability") def intercepted_func(*args: Any, **kwargs: Any) -> Any: """Execute the tool with interrupt check.""" tool_name = tool.name - logger.debug(f"[ToolInterceptor] Executing tool: {tool_name}") + safe_tool_name_local = sanitize_tool_name(tool_name) + logger.debug(f"[ToolInterceptor] Executing tool: {safe_tool_name_local}") # Format tool input for display tool_input = args[0] if args else kwargs tool_input_repr = ToolInterceptor._format_tool_input(tool_input) - logger.debug(f"[ToolInterceptor] Tool input: {tool_input_repr[:200]}") + safe_tool_input = sanitize_log_input(tool_input_repr, max_length=100) + logger.debug(f"[ToolInterceptor] Tool input: {safe_tool_input}") should_interrupt = interceptor.should_interrupt(tool_name) - logger.debug(f"[ToolInterceptor] should_interrupt={should_interrupt} for tool '{tool_name}'") + logger.debug(f"[ToolInterceptor] should_interrupt={should_interrupt} for tool '{safe_tool_name_local}'") if should_interrupt: logger.info( - f"[ToolInterceptor] Interrupting before tool '{tool_name}'" + f"[ToolInterceptor] Interrupting before tool '{safe_tool_name_local}'" ) logger.debug( - f"[ToolInterceptor] Interrupt message: About to execute tool '{tool_name}' with input: {tool_input_repr[:100]}..." + f"[ToolInterceptor] Interrupt message: About to execute tool '{safe_tool_name_local}' with input: {safe_tool_input}..." ) # Trigger interrupt and wait for user feedback @@ -112,41 +121,43 @@ class ToolInterceptor: feedback = interrupt( f"About to execute tool: '{tool_name}'\n\nInput:\n{tool_input_repr}\n\nApprove execution?" ) - logger.debug(f"[ToolInterceptor] Interrupt returned with feedback: {f'{feedback[:100]}...' if feedback and len(feedback) > 100 else feedback if feedback else 'None'}") + safe_feedback = sanitize_feedback(feedback) + logger.debug(f"[ToolInterceptor] Interrupt returned with feedback: {f'{safe_feedback[:100]}...' 
if safe_feedback and len(safe_feedback) > 100 else safe_feedback if safe_feedback else 'None'}") except Exception as e: logger.error(f"[ToolInterceptor] Error during interrupt: {str(e)}") raise - logger.debug(f"[ToolInterceptor] Processing feedback approval for '{tool_name}'") + logger.debug(f"[ToolInterceptor] Processing feedback approval for '{safe_tool_name_local}'") # Check if user approved is_approved = ToolInterceptor._parse_approval(feedback) - logger.info(f"[ToolInterceptor] Tool '{tool_name}' approval decision: {is_approved}") + logger.info(f"[ToolInterceptor] Tool '{safe_tool_name_local}' approval decision: {is_approved}") if not is_approved: - logger.warning(f"[ToolInterceptor] User rejected execution of tool '{tool_name}'") + logger.warning(f"[ToolInterceptor] User rejected execution of tool '{safe_tool_name_local}'") return { "error": f"Tool execution rejected by user", "tool": tool_name, "status": "rejected", } - logger.info(f"[ToolInterceptor] User approved execution of tool '{tool_name}', proceeding") + logger.info(f"[ToolInterceptor] User approved execution of tool '{safe_tool_name_local}', proceeding") # Execute the original tool try: - logger.debug(f"[ToolInterceptor] Calling original function for tool '{tool_name}'") + logger.debug(f"[ToolInterceptor] Calling original function for tool '{safe_tool_name_local}'") result = original_func(*args, **kwargs) - logger.info(f"[ToolInterceptor] Tool '{tool_name}' execution completed successfully") - logger.debug(f"[ToolInterceptor] Tool result length: {len(str(result))}") + logger.info(f"[ToolInterceptor] Tool '{safe_tool_name_local}' execution completed successfully") + result_len = len(str(result)) + logger.debug(f"[ToolInterceptor] Tool result length: {result_len}") return result except Exception as e: - logger.error(f"[ToolInterceptor] Error executing tool '{tool_name}': {str(e)}") + logger.error(f"[ToolInterceptor] Error executing tool '{safe_tool_name_local}': {str(e)}") raise # Replace the function 
and update the tool # Use object.__setattr__ to bypass Pydantic validation - logger.debug(f"Attaching intercepted function to tool '{tool.name}'") + logger.debug(f"Attaching intercepted function to tool '{safe_tool_name}'") object.__setattr__(tool, "func", intercepted_func) return tool diff --git a/src/crawler/readability_extractor.py b/src/crawler/readability_extractor.py index 87b3b97..698d5b6 100644 --- a/src/crawler/readability_extractor.py +++ b/src/crawler/readability_extractor.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT import logging + from readabilipy import simple_json_from_html_string from .article import Article diff --git a/src/server/app.py b/src/server/app.py index 33bf7a8..9c3e545 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -55,6 +55,13 @@ from src.server.rag_request import ( ) from src.tools import VolcengineTTS from src.utils.json_utils import sanitize_args +from src.utils.log_sanitizer import ( + sanitize_agent_name, + sanitize_log_input, + sanitize_thread_id, + sanitize_tool_name, + sanitize_user_content, +) logger = logging.getLogger(__name__) @@ -333,9 +340,13 @@ def _process_initial_messages(message, thread_id): async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent): """Process a single message chunk and yield appropriate events.""" + agent_name = _get_agent_name(agent, message_metadata) - logger.debug(f"[{thread_id}] _process_message_chunk started for agent_name={agent_name}") - logger.debug(f"[{thread_id}] Extracted agent_name: {agent_name}") + safe_agent_name = sanitize_agent_name(agent_name) + safe_thread_id = sanitize_thread_id(thread_id) + safe_agent = sanitize_agent_name(agent) + logger.debug(f"[{safe_thread_id}] _process_message_chunk started for agent={safe_agent_name}") + logger.debug(f"[{safe_thread_id}] Extracted agent_name: {safe_agent_name}") event_stream_message = _create_event_stream_message( message_chunk, message_metadata, thread_id, agent_name @@ -343,25 +354,29 @@ async def 
_process_message_chunk(message_chunk, message_metadata, thread_id, age if isinstance(message_chunk, ToolMessage): # Tool Message - Return the result of the tool call - logger.debug(f"[{thread_id}] Processing ToolMessage") + logger.debug(f"[{safe_thread_id}] Processing ToolMessage") tool_call_id = message_chunk.tool_call_id event_stream_message["tool_call_id"] = tool_call_id # Validate tool_call_id for debugging if tool_call_id: - logger.debug(f"[{thread_id}] ToolMessage with tool_call_id: {tool_call_id}") + safe_tool_id = sanitize_log_input(tool_call_id, max_length=100) + logger.debug(f"[{safe_thread_id}] ToolMessage with tool_call_id: {safe_tool_id}") else: - logger.warning(f"[{thread_id}] ToolMessage received without tool_call_id") + logger.warning(f"[{safe_thread_id}] ToolMessage received without tool_call_id") - logger.debug(f"[{thread_id}] Yielding tool_call_result event") + logger.debug(f"[{safe_thread_id}] Yielding tool_call_result event") yield _make_event("tool_call_result", event_stream_message) elif isinstance(message_chunk, AIMessageChunk): # AI Message - Raw message tokens - logger.debug(f"[{thread_id}] Processing AIMessageChunk, tool_calls={bool(message_chunk.tool_calls)}, tool_call_chunks={bool(message_chunk.tool_call_chunks)}") + has_tool_calls = bool(message_chunk.tool_calls) + has_chunks = bool(message_chunk.tool_call_chunks) + logger.debug(f"[{safe_thread_id}] Processing AIMessageChunk, tool_calls={has_tool_calls}, tool_call_chunks={has_chunks}") if message_chunk.tool_calls: # AI Message - Tool Call (complete tool calls) - logger.debug(f"[{thread_id}] AIMessageChunk has complete tool_calls: {[tc.get('name', 'unknown') for tc in message_chunk.tool_calls]}") + safe_tool_names = [sanitize_tool_name(tc.get('name', 'unknown')) for tc in message_chunk.tool_calls] + logger.debug(f"[{safe_thread_id}] AIMessageChunk has complete tool_calls: {safe_tool_names}") event_stream_message["tool_calls"] = message_chunk.tool_calls # Process tool_call_chunks with 
proper index-based grouping @@ -370,16 +385,18 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age ) if processed_chunks: event_stream_message["tool_call_chunks"] = processed_chunks + safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks] logger.debug( - f"[{thread_id}] Tool calls: {[tc.get('name') for tc in message_chunk.tool_calls]}, " + f"[{safe_thread_id}] Tool calls: {safe_tool_names}, " f"Processed chunks: {len(processed_chunks)}" ) - logger.debug(f"[{thread_id}] Yielding tool_calls event") + logger.debug(f"[{safe_thread_id}] Yielding tool_calls event") yield _make_event("tool_calls", event_stream_message) elif message_chunk.tool_call_chunks: # AI Message - Tool Call Chunks (streaming) - logger.debug(f"[{thread_id}] AIMessageChunk has streaming tool_call_chunks: {len(message_chunk.tool_call_chunks)} chunks") + chunks_count = len(message_chunk.tool_call_chunks) + logger.debug(f"[{safe_thread_id}] AIMessageChunk has streaming tool_call_chunks: {chunks_count} chunks") processed_chunks = _process_tool_call_chunks( message_chunk.tool_call_chunks ) @@ -392,26 +409,30 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age # Log index transitions to detect tool call boundaries if prev_chunk is not None and current_index != prev_chunk.get("index"): + prev_name = sanitize_tool_name(prev_chunk.get('name')) + curr_name = sanitize_tool_name(chunk.get('name')) logger.debug( - f"[{thread_id}] Tool call boundary detected: " - f"index {prev_chunk.get('index')} ({prev_chunk.get('name')}) -> " - f"{current_index} ({chunk.get('name')})" + f"[{safe_thread_id}] Tool call boundary detected: " + f"index {prev_chunk.get('index')} ({prev_name}) -> " + f"{current_index} ({curr_name})" ) prev_chunk = chunk # Include all processed chunks in the event event_stream_message["tool_call_chunks"] = processed_chunks + safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks] 
logger.debug( - f"[{thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): " - f"{[c.get('name') for c in processed_chunks]}" + f"[{safe_thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): " + f"{safe_chunk_names}" ) - logger.debug(f"[{thread_id}] Yielding tool_call_chunks event") + logger.debug(f"[{safe_thread_id}] Yielding tool_call_chunks event") yield _make_event("tool_call_chunks", event_stream_message) else: # AI Message - Raw message tokens - logger.debug(f"[{thread_id}] AIMessageChunk is raw message tokens, content_len={len(message_chunk.content) if isinstance(message_chunk.content, str) else 'unknown'}") + content_len = len(message_chunk.content) if isinstance(message_chunk.content, str) else 0 + logger.debug(f"[{safe_thread_id}] AIMessageChunk is raw message tokens, content_len={content_len}") yield _make_event("message_chunk", event_stream_message) @@ -419,7 +440,8 @@ async def _stream_graph_events( graph_instance, workflow_input, workflow_config, thread_id ): """Stream events from the graph and process them.""" - logger.debug(f"[{thread_id}] Starting graph event stream with agent nodes") + safe_thread_id = sanitize_thread_id(thread_id) + logger.debug(f"[{safe_thread_id}] Starting graph event stream with agent nodes") try: event_count = 0 async for agent, _, event_data in graph_instance.astream( @@ -429,28 +451,31 @@ async def _stream_graph_events( subgraphs=True, ): event_count += 1 - logger.debug(f"[{thread_id}] Graph event #{event_count} received from agent: {agent}") + safe_agent = sanitize_agent_name(agent) + logger.debug(f"[{safe_thread_id}] Graph event #{event_count} received from agent: {safe_agent}") if isinstance(event_data, dict): if "__interrupt__" in event_data: logger.debug( - f"[{thread_id}] Processing interrupt event: " + f"[{safe_thread_id}] Processing interrupt event: " f"ns={getattr(event_data['__interrupt__'][0], 'ns', 'unknown') if isinstance(event_data['__interrupt__'], (list, tuple)) and 
len(event_data['__interrupt__']) > 0 else 'unknown'}, " f"value_len={len(getattr(event_data['__interrupt__'][0], 'value', '')) if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 and hasattr(event_data['__interrupt__'][0], 'value') and hasattr(event_data['__interrupt__'][0].value, '__len__') else 'unknown'}" ) yield _create_interrupt_event(thread_id, event_data) - logger.debug(f"[{thread_id}] Dict event without interrupt, skipping") + logger.debug(f"[{safe_thread_id}] Dict event without interrupt, skipping") continue message_chunk, message_metadata = cast( tuple[BaseMessage, dict[str, Any]], event_data ) + safe_node = sanitize_agent_name(message_metadata.get('langgraph_node', 'unknown')) + safe_step = sanitize_log_input(message_metadata.get('langgraph_step', 'unknown')) logger.debug( - f"[{thread_id}] Processing message chunk: " + f"[{safe_thread_id}] Processing message chunk: " f"type={type(message_chunk).__name__}, " - f"node={message_metadata.get('langgraph_node', 'unknown')}, " - f"step={message_metadata.get('langgraph_step', 'unknown')}" + f"node={safe_node}, " + f"step={safe_step}" ) async for event in _process_message_chunk( @@ -458,9 +483,9 @@ async def _stream_graph_events( ): yield event - logger.debug(f"[{thread_id}] Graph event stream completed. Total events: {event_count}") + logger.debug(f"[{safe_thread_id}] Graph event stream completed. 
Total events: {event_count}") except Exception as e: - logger.exception(f"[{thread_id}] Error during graph execution") + logger.exception(f"[{safe_thread_id}] Error during graph execution") yield _make_event( "error", { @@ -488,34 +513,38 @@ async def _astream_workflow_generator( locale: str = "en-US", interrupt_before_tools: Optional[List[str]] = None, ): + safe_thread_id = sanitize_thread_id(thread_id) + safe_feedback = sanitize_log_input(interrupt_feedback) if interrupt_feedback else "" logger.debug( - f"[{thread_id}] _astream_workflow_generator starting: " + f"[{safe_thread_id}] _astream_workflow_generator starting: " f"messages_count={len(messages)}, " f"auto_accepted_plan={auto_accepted_plan}, " - f"interrupt_feedback={interrupt_feedback}, " + f"interrupt_feedback={safe_feedback}, " f"interrupt_before_tools={interrupt_before_tools}" ) # Process initial messages - logger.debug(f"[{thread_id}] Processing {len(messages)} initial messages") + logger.debug(f"[{safe_thread_id}] Processing {len(messages)} initial messages") for message in messages: if isinstance(message, dict) and "content" in message: - logger.debug(f"[{thread_id}] Sending initial message to client: {message.get('content', '')[:100]}") + safe_content = sanitize_user_content(message.get('content', '')) + logger.debug(f"[{safe_thread_id}] Sending initial message to client: {safe_content}") _process_initial_messages(message, thread_id) - logger.debug(f"[{thread_id}] Reconstructing clarification history") + logger.debug(f"[{safe_thread_id}] Reconstructing clarification history") clarification_history = reconstruct_clarification_history(messages) - logger.debug(f"[{thread_id}] Building clarified topic from history") + logger.debug(f"[{safe_thread_id}] Building clarified topic from history") clarified_topic, clarification_history = build_clarified_topic_from_history( clarification_history ) latest_message_content = messages[-1]["content"] if messages else "" clarified_research_topic = clarified_topic or 
latest_message_content - logger.debug(f"[{thread_id}] Clarified research topic: {clarified_research_topic[:100]}") + safe_topic = sanitize_user_content(clarified_research_topic) + logger.debug(f"[{safe_thread_id}] Clarified research topic: {safe_topic}") # Prepare workflow input - logger.debug(f"[{thread_id}] Preparing workflow input") + logger.debug(f"[{safe_thread_id}] Preparing workflow input") workflow_input = { "messages": messages, "plan_iterations": 0, @@ -533,7 +562,7 @@ async def _astream_workflow_generator( } if not auto_accepted_plan and interrupt_feedback: - logger.debug(f"[{thread_id}] Creating resume command with interrupt_feedback: {interrupt_feedback}") + logger.debug(f"[{safe_thread_id}] Creating resume command with interrupt_feedback: {safe_feedback}") resume_msg = f"[{interrupt_feedback}]" if messages: resume_msg += f" {messages[-1]['content']}" @@ -541,7 +570,7 @@ async def _astream_workflow_generator( # Prepare workflow config logger.debug( - f"[{thread_id}] Preparing workflow config: " + f"[{safe_thread_id}] Preparing workflow config: " f"max_plan_iterations={max_plan_iterations}, " f"max_step_num={max_step_num}, " f"report_style={report_style.value}, " @@ -564,7 +593,7 @@ async def _astream_workflow_generator( checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "") logger.debug( - f"[{thread_id}] Checkpoint configuration: " + f"[{safe_thread_id}] Checkpoint configuration: " f"saver_enabled={checkpoint_saver}, " f"url_configured={bool(checkpoint_url)}" ) @@ -577,48 +606,48 @@ async def _astream_workflow_generator( } if checkpoint_saver and checkpoint_url != "": if checkpoint_url.startswith("postgresql://"): - logger.info(f"[{thread_id}] Starting async postgres checkpointer") - logger.debug(f"[{thread_id}] Setting up PostgreSQL connection pool") + logger.info(f"[{safe_thread_id}] Starting async postgres checkpointer") + logger.debug(f"[{safe_thread_id}] Setting up PostgreSQL connection pool") async with AsyncConnectionPool( 
checkpoint_url, kwargs=connection_kwargs ) as conn: - logger.debug(f"[{thread_id}] Initializing AsyncPostgresSaver") + logger.debug(f"[{safe_thread_id}] Initializing AsyncPostgresSaver") checkpointer = AsyncPostgresSaver(conn) await checkpointer.setup() - logger.debug(f"[{thread_id}] Attaching checkpointer to graph") + logger.debug(f"[{safe_thread_id}] Attaching checkpointer to graph") graph.checkpointer = checkpointer graph.store = in_memory_store - logger.debug(f"[{thread_id}] Starting to stream graph events") + logger.debug(f"[{safe_thread_id}] Starting to stream graph events") async for event in _stream_graph_events( graph, workflow_input, workflow_config, thread_id ): yield event - logger.debug(f"[{thread_id}] Graph event streaming completed") + logger.debug(f"[{safe_thread_id}] Graph event streaming completed") if checkpoint_url.startswith("mongodb://"): - logger.info(f"[{thread_id}] Starting async mongodb checkpointer") - logger.debug(f"[{thread_id}] Setting up MongoDB connection") + logger.info(f"[{safe_thread_id}] Starting async mongodb checkpointer") + logger.debug(f"[{safe_thread_id}] Setting up MongoDB connection") async with AsyncMongoDBSaver.from_conn_string( checkpoint_url ) as checkpointer: - logger.debug(f"[{thread_id}] Attaching MongoDB checkpointer to graph") + logger.debug(f"[{safe_thread_id}] Attaching MongoDB checkpointer to graph") graph.checkpointer = checkpointer graph.store = in_memory_store - logger.debug(f"[{thread_id}] Starting to stream graph events") + logger.debug(f"[{safe_thread_id}] Starting to stream graph events") async for event in _stream_graph_events( graph, workflow_input, workflow_config, thread_id ): yield event - logger.debug(f"[{thread_id}] Graph event streaming completed") + logger.debug(f"[{safe_thread_id}] Graph event streaming completed") else: - logger.debug(f"[{thread_id}] No checkpointer configured, using in-memory graph") + logger.debug(f"[{safe_thread_id}] No checkpointer configured, using in-memory graph") # Use 
graph without MongoDB checkpointer - logger.debug(f"[{thread_id}] Starting to stream graph events") + logger.debug(f"[{safe_thread_id}] Starting to stream graph events") async for event in _stream_graph_events( graph, workflow_input, workflow_config, thread_id ): yield event - logger.debug(f"[{thread_id}] Graph event streaming completed") + logger.debug(f"[{safe_thread_id}] Graph event streaming completed") def _make_event(event_type: str, data: dict[str, any]): diff --git a/src/utils/log_sanitizer.py b/src/utils/log_sanitizer.py new file mode 100644 index 0000000..b6f45fe --- /dev/null +++ b/src/utils/log_sanitizer.py @@ -0,0 +1,186 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Log sanitization utilities to prevent log injection attacks. + +This module provides functions to sanitize user-controlled input before +logging to prevent attackers from forging log entries through: +- Newline injection (\n) +- HTML injection (for HTML logs) +- Special character sequences that could be misinterpreted +""" + +import re +from typing import Any, Optional + + +def sanitize_log_input(value: Any, max_length: int = 500) -> str: + """ + Sanitize user-controlled input for safe logging. + + Replaces dangerous characters (newlines, tabs, carriage returns, etc.) + with their escaped representations to prevent log injection attacks. 
+ + Args: + value: The input value to sanitize (any type) + max_length: Maximum length of output string (truncates if exceeded) + + Returns: + str: Sanitized string safe for logging + + Examples: + >>> sanitize_log_input("normal text") + 'normal text' + + >>> sanitize_log_input("malicious\n[INFO] fake entry") + 'malicious\\n[INFO] fake entry' + + >>> sanitize_log_input("tab\there") + 'tab\\there' + + >>> sanitize_log_input(None) + 'None' + + >>> long_text = "a" * 1000 + >>> result = sanitize_log_input(long_text, max_length=100) + >>> len(result) <= 100 + True + """ + if value is None: + return "None" + + # Convert to string + string_value = str(value) + + # Replace dangerous characters with their escaped representations + # Order matters: escape backslashes first to avoid double-escaping + replacements = { + "\\": "\\\\", # Backslash (must be first) + "\n": "\\n", # Newline - prevents creating new log entries + "\r": "\\r", # Carriage return + "\t": "\\t", # Tab + "\x00": "\\0", # Null character + "\x1b": "\\x1b", # Escape character (used in ANSI sequences) + } + + for char, replacement in replacements.items(): + string_value = string_value.replace(char, replacement) + + # Remove other control characters (ASCII 0-31 except those already handled) + # These are rarely useful in logs and could be exploited + string_value = re.sub(r"[\x00-\x08\x0b-\x0c\x0e-\x1f]", "", string_value) + + # Truncate if too long (prevent log flooding) + if len(string_value) > max_length: + string_value = string_value[: max_length - 3] + "..." + + return string_value + + +def sanitize_thread_id(thread_id: Any) -> str: + """ + Sanitize thread_id for logging. + + Thread IDs should be alphanumeric with hyphens and underscores, + but we sanitize to be defensive. 
+ + Args: + thread_id: The thread ID to sanitize + + Returns: + str: Sanitized thread ID + """ + return sanitize_log_input(thread_id, max_length=100) + + +def sanitize_user_content(content: Any) -> str: + """ + Sanitize user-provided message content for logging. + + User messages can be arbitrary length, so we truncate more aggressively. + + Args: + content: The user content to sanitize + + Returns: + str: Sanitized user content + """ + return sanitize_log_input(content, max_length=200) + + +def sanitize_agent_name(agent_name: Any) -> str: + """ + Sanitize agent name for logging. + + Agent names should be simple identifiers, but we sanitize to be defensive. + + Args: + agent_name: The agent name to sanitize + + Returns: + str: Sanitized agent name + """ + return sanitize_log_input(agent_name, max_length=100) + + +def sanitize_tool_name(tool_name: Any) -> str: + """ + Sanitize tool name for logging. + + Tool names should be simple identifiers, but we sanitize to be defensive. + + Args: + tool_name: The tool name to sanitize + + Returns: + str: Sanitized tool name + """ + return sanitize_log_input(tool_name, max_length=100) + + +def sanitize_feedback(feedback: Any) -> str: + """ + Sanitize user feedback for logging. + + Feedback can be arbitrary text from interrupts, so sanitize carefully. + + Args: + feedback: The feedback to sanitize + + Returns: + str: Sanitized feedback (truncated more aggressively) + """ + return sanitize_log_input(feedback, max_length=150) + + +def create_safe_log_message(template: str, **kwargs) -> str: + """ + Create a safe log message by sanitizing all values. + + Uses a template string with keyword arguments, sanitizing each value + before substitution to prevent log injection. + + Args: + template: Template string with {key} placeholders + **kwargs: Key-value pairs to substitute + + Returns: + str: Safe log message + + Example: + >>> msg = create_safe_log_message( + ... "[{thread_id}] Processing {tool_name}", + ... 
thread_id="abc\\n[INFO]", + ... tool_name="my_tool" + ... ) + >>> "[abc\\\\n[INFO]] Processing my_tool" in msg + True + """ + # Sanitize all values + safe_kwargs = { + key: sanitize_log_input(value) for key, value in kwargs.items() + } + + # Substitute into template + return template.format(**safe_kwargs) diff --git a/tests/integration/test_tool_interceptor_integration.py b/tests/integration/test_tool_interceptor_integration.py index a28fee4..20b017e 100644 --- a/tests/integration/test_tool_interceptor_integration.py +++ b/tests/integration/test_tool_interceptor_integration.py @@ -11,12 +11,12 @@ Tests the complete flow of selective tool interrupts including: - Resume mechanism after interrupt """ -import pytest -from unittest.mock import Mock, patch, AsyncMock, MagicMock, call from typing import Any +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch -from langchain_core.tools import tool +import pytest from langchain_core.messages import HumanMessage +from langchain_core.tools import tool from src.agents.agents import create_agent from src.agents.tool_interceptor import ToolInterceptor, wrap_tools_with_interceptor diff --git a/tests/unit/agents/test_tool_interceptor.py b/tests/unit/agents/test_tool_interceptor.py index 8fa37df..5be1cd7 100644 --- a/tests/unit/agents/test_tool_interceptor.py +++ b/tests/unit/agents/test_tool_interceptor.py @@ -1,9 +1,10 @@ # Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates # SPDX-License-Identifier: MIT +from unittest.mock import AsyncMock, MagicMock, Mock, patch + import pytest -from unittest.mock import Mock, patch, MagicMock, AsyncMock -from langchain_core.tools import tool, BaseTool +from langchain_core.tools import BaseTool, tool from src.agents.tool_interceptor import ( ToolInterceptor, diff --git a/tests/unit/crawler/test_jina_client.py b/tests/unit/crawler/test_jina_client.py index 94edade..087b5d2 100644 --- a/tests/unit/crawler/test_jina_client.py +++ b/tests/unit/crawler/test_jina_client.py @@ -1,8 +1,10 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +from unittest.mock import Mock, patch + import pytest -from unittest.mock import patch, Mock + from src.crawler.jina_client import JinaClient diff --git a/tests/unit/crawler/test_readability_extractor.py b/tests/unit/crawler/test_readability_extractor.py index 0e375fa..e4226e9 100644 --- a/tests/unit/crawler/test_readability_extractor.py +++ b/tests/unit/crawler/test_readability_extractor.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT from unittest.mock import patch + from src.crawler.readability_extractor import ReadabilityExtractor diff --git a/tests/unit/graph/test_plan_validation.py b/tests/unit/graph/test_plan_validation.py index 64d6ac7..79fe3e7 100644 --- a/tests/unit/graph/test_plan_validation.py +++ b/tests/unit/graph/test_plan_validation.py @@ -1,8 +1,9 @@ # Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates # SPDX-License-Identifier: MIT +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock from src.graph.nodes import validate_and_fix_plan diff --git a/tests/unit/server/test_tool_call_chunks.py b/tests/unit/server/test_tool_call_chunks.py index f4e017a..cd73360 100644 --- a/tests/unit/server/test_tool_call_chunks.py +++ b/tests/unit/server/test_tool_call_chunks.py @@ -10,13 +10,14 @@ tool names from being concatenated when multiple tool calls happen in sequence. """ import logging -import pytest -from unittest.mock import patch, MagicMock +import os # Import the functions to test # Note: We need to import from the app module import sys -import os +from unittest.mock import MagicMock, patch + +import pytest # Add src directory to path for imports sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../")) diff --git a/tests/unit/utils/test_json_utils.py b/tests/unit/utils/test_json_utils.py index 0cc9795..e9ead1a 100644 --- a/tests/unit/utils/test_json_utils.py +++ b/tests/unit/utils/test_json_utils.py @@ -3,7 +3,11 @@ import json -from src.utils.json_utils import repair_json_output, sanitize_tool_response, _extract_json_from_content +from src.utils.json_utils import ( + _extract_json_from_content, + repair_json_output, + sanitize_tool_response, +) class TestRepairJsonOutput: diff --git a/tests/unit/utils/test_log_sanitizer.py b/tests/unit/utils/test_log_sanitizer.py new file mode 100644 index 0000000..ffd989c --- /dev/null +++ b/tests/unit/utils/test_log_sanitizer.py @@ -0,0 +1,268 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Unit tests for log sanitization utilities. + +This test file verifies that the log sanitizer properly prevents log injection attacks +by escaping dangerous characters in user-controlled input before logging. 
+""" + +import pytest + +from src.utils.log_sanitizer import ( + create_safe_log_message, + sanitize_agent_name, + sanitize_feedback, + sanitize_log_input, + sanitize_thread_id, + sanitize_tool_name, + sanitize_user_content, +) + + +class TestSanitizeLogInput: + """Test the main sanitize_log_input function.""" + + def test_sanitize_normal_text(self): + """Test that normal text is preserved.""" + text = "normal text" + result = sanitize_log_input(text) + assert result == "normal text" + + def test_sanitize_newline_injection(self): + """Test prevention of newline injection attack.""" + malicious = "abc\n[INFO] Forged log entry" + result = sanitize_log_input(malicious) + assert "\n" not in result + assert "[INFO]" in result # The attack text is preserved but escaped + assert "\\n" in result # Newline is escaped + + def test_sanitize_carriage_return(self): + """Test prevention of carriage return injection.""" + malicious = "text\r[WARN] Forged entry" + result = sanitize_log_input(malicious) + assert "\r" not in result + assert "\\r" in result + + def test_sanitize_tab_character(self): + """Test prevention of tab character injection.""" + malicious = "text\t[ERROR] Forged" + result = sanitize_log_input(malicious) + assert "\t" not in result + assert "\\t" in result + + def test_sanitize_null_character(self): + """Test prevention of null character injection.""" + malicious = "text\x00[CRITICAL]" + result = sanitize_log_input(malicious) + assert "\x00" not in result + + def test_sanitize_backslash(self): + """Test that backslashes are properly escaped.""" + text = "path\\to\\file" + result = sanitize_log_input(text) + assert result == "path\\\\to\\\\file" + + def test_sanitize_escape_character(self): + """Test prevention of ANSI escape sequence injection.""" + malicious = "text\x1b[31mRED TEXT\x1b[0m" + result = sanitize_log_input(malicious) + assert "\x1b" not in result + assert "\\x1b" in result + + def test_sanitize_max_length_truncation(self): + """Test that long 
strings are truncated.""" + long_text = "a" * 1000 + result = sanitize_log_input(long_text, max_length=100) + assert len(result) <= 100 + assert result.endswith("...") + + def test_sanitize_none_value(self): + """Test that None is handled properly.""" + result = sanitize_log_input(None) + assert result == "None" + + def test_sanitize_numeric_value(self): + """Test that numeric values are converted to strings.""" + result = sanitize_log_input(12345) + assert result == "12345" + + def test_sanitize_complex_injection_attack(self): + """Test complex multi-character injection attack.""" + malicious = 'thread-123\n[WARNING] Unauthorized\r[ERROR] System failure\t[CRITICAL] Shutdown' + result = sanitize_log_input(malicious) + # All dangerous characters should be escaped + assert "\n" not in result + assert "\r" not in result + assert "\t" not in result + # But the text should still be there (escaped) + assert "WARNING" in result + assert "ERROR" in result + + +class TestSanitizeThreadId: + """Test sanitization of thread IDs.""" + + def test_thread_id_normal(self): + """Test normal thread ID.""" + thread_id = "thread-123-abc" + result = sanitize_thread_id(thread_id) + assert result == "thread-123-abc" + + def test_thread_id_with_newline(self): + """Test thread ID with newline injection.""" + malicious = "thread-1\n[INFO] Forged" + result = sanitize_thread_id(malicious) + assert "\n" not in result + assert "\\n" in result + + def test_thread_id_max_length(self): + """Test that thread ID truncation respects max length.""" + long_id = "x" * 200 + result = sanitize_thread_id(long_id) + assert len(result) <= 100 + + +class TestSanitizeUserContent: + """Test sanitization of user-provided message content.""" + + def test_user_content_normal(self): + """Test normal user content.""" + content = "What is the weather today?" + result = sanitize_user_content(content) + assert result == "What is the weather today?" 
+ + def test_user_content_with_newline(self): + """Test user content with newline.""" + malicious = "My question\n[ADMIN] Delete user" + result = sanitize_user_content(malicious) + assert "\n" not in result + assert "\\n" in result + + def test_user_content_max_length(self): + """Test that user content is truncated more aggressively.""" + long_content = "x" * 500 + result = sanitize_user_content(long_content) + assert len(result) <= 200 + + +class TestSanitizeToolName: + """Test sanitization of tool names.""" + + def test_tool_name_normal(self): + """Test normal tool name.""" + tool = "web_search" + result = sanitize_tool_name(tool) + assert result == "web_search" + + def test_tool_name_injection(self): + """Test tool name with injection attempt.""" + malicious = "search\n[WARN] Forged" + result = sanitize_tool_name(malicious) + assert "\n" not in result + + +class TestSanitizeFeedback: + """Test sanitization of user feedback.""" + + def test_feedback_normal(self): + """Test normal feedback.""" + feedback = "[accepted]" + result = sanitize_feedback(feedback) + assert result == "[accepted]" + + def test_feedback_injection(self): + """Test feedback with injection attempt.""" + malicious = "[approved]\n[CRITICAL] System down" + result = sanitize_feedback(malicious) + assert "\n" not in result + assert "\\n" in result + + def test_feedback_max_length(self): + """Test that feedback is truncated.""" + long_feedback = "x" * 500 + result = sanitize_feedback(long_feedback) + assert len(result) <= 150 + + +class TestCreateSafeLogMessage: + """Test the create_safe_log_message helper function.""" + + def test_safe_message_normal(self): + """Test normal message creation.""" + msg = create_safe_log_message( + "[{thread_id}] Processing {tool_name}", + thread_id="thread-1", + tool_name="search", + ) + assert "[thread-1] Processing search" == msg + + def test_safe_message_with_injection(self): + """Test message creation with injected values.""" + msg = create_safe_log_message( + 
            "[{thread_id}] Tool: {tool_name}",
+            thread_id="id\n[INFO] Forged",
+            tool_name="search\r[ERROR]",
+        )
+        # The dangerous characters should be escaped
+        assert "\n" not in msg
+        assert "\r" not in msg
+        assert "\\n" in msg
+        assert "\\r" in msg
+
+    def test_safe_message_multiple_values(self):
+        """Test message with multiple values."""
+        msg = create_safe_log_message(
+            "[{id}] User: {user} Tool: {tool}",
+            id="123",
+            user="admin\t[WARN]",
+            tool="delete\x1b[31m",
+        )
+        assert "\t" not in msg
+        assert "\x1b" not in msg
+
+
+class TestLogInjectionAttackPrevention:
+    """Integration tests for log injection prevention."""
+
+    def test_classic_log_injection_newline(self):
+        """Test the classic log injection attack using newlines."""
+        attacker_input = 'abc\n[WARNING] Unauthorized access detected'
+        result = sanitize_log_input(attacker_input)
+        # The output should not contain an actual newline that would create a new log entry
+        assert result.count("\n") == 0
+        # But the escaped version should be in there
+        assert "\\n" in result
+
+    def test_carriage_return_log_injection(self):
+        """Test log injection via carriage return."""
+        attacker_input = "request_id\r\n[ERROR] CRITICAL FAILURE"
+        result = sanitize_log_input(attacker_input)
+        assert "\r" not in result
+        assert "\n" not in result
+
+    def test_html_injection_prevention(self):
+        """Test sanitization of terminal escape sequences embedded in log input."""
+        # The payload is an ANSI color escape (\x1b[32m), not HTML markup;
+        # escaping the control character defeats attacks on log viewers/parsers.
+        malicious_html = "user\x1b[32m"
+        result = sanitize_log_input(malicious_html)
+        assert "\x1b" not in result
+        # The printable text is preserved; only the control character is escaped
+        assert "