security: add log injection attack prevention with input sanitization (#667)

* security: add log injection attack prevention with input sanitization

- Created src/utils/log_sanitizer.py to sanitize user-controlled input before logging
- Prevents log injection attacks using newlines, tabs, carriage returns, etc.
- Escapes dangerous characters: \n, \r, \t, \0, \x1b
- Provides specialized functions for different input types (see the usage sketch after this list):
  - sanitize_log_input: general purpose sanitization
  - sanitize_thread_id: for user-provided thread IDs
  - sanitize_user_content: for user messages (more aggressive truncation)
  - sanitize_agent_name: for agent identifiers
  - sanitize_tool_name: for tool names
  - sanitize_feedback: for user interrupt feedback
  - create_safe_log_message: template-based safe message creation
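
A minimal usage sketch (the sanitize_* helpers are the ones added in
src/utils/log_sanitizer.py; handle_request and the logger wiring below are
illustrative only):

    import logging

    from src.utils.log_sanitizer import sanitize_thread_id, sanitize_user_content

    logger = logging.getLogger(__name__)

    def handle_request(thread_id: str, user_message: str) -> None:
        # Escape control characters before the values reach the log record
        safe_id = sanitize_thread_id(thread_id)
        safe_msg = sanitize_user_content(user_message)
        logger.info(f"[{safe_id}] received message: {safe_msg}")

    # A payload such as "abc\n[INFO] fake entry" is logged as
    # "abc\\n[INFO] fake entry" and stays on a single line.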

- Updated src/server/app.py to sanitize all user input in logging:
  - Thread IDs from request parameter
  - Message content from user
  - Agent names and node information
  - Tool names and feedback

- Updated src/agents/tool_interceptor.py to sanitize:
  - Tool names during execution
  - User feedback during interrupt handling
  - Tool input data

- Added 29 comprehensive unit tests covering:
  - Classic newline injection attacks
  - Carriage return injection
  - Tab and null character injection
  - HTML/ANSI escape sequence injection
  - Combined multi-character attacks
  - Truncation and length limits

Fixes a potential log-forgery vulnerability where a malicious user could inject
fake log entries via unsanitized input containing control characters.
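
Illustrative before/after (a sketch; the attacker-controlled payload below is
hypothetical, the sanitizer is the one added in this commit):

    from src.utils.log_sanitizer import sanitize_log_input

    payload = "user-42\n2025-10-27 INFO Payment approved"  # forged second record
    # Unsanitized, the embedded newline would render as a separate log line.
    print(sanitize_log_input(payload))
    # -> user-42\n2025-10-27 INFO Payment approved   (newline escaped, one line)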
Author: Willem Jiang
Date: 2025-10-27 20:57:23 +08:00 (committed via GitHub)
Parent: ccd7535072
Commit: b4c09aa4b1
13 changed files with 585 additions and 80 deletions


@@ -6,10 +6,10 @@ from typing import List, Optional
 from langgraph.prebuilt import create_react_agent
+from src.agents.tool_interceptor import wrap_tools_with_interceptor
 from src.config.agents import AGENT_LLM_MAP
 from src.llms.llm import get_llm_by_type
 from src.prompts import apply_prompt_template
-from src.agents.tool_interceptor import wrap_tools_with_interceptor
 logger = logging.getLogger(__name__)


@@ -8,6 +8,12 @@ from typing import Any, Callable, List, Optional
 from langchain_core.tools import BaseTool
 from langgraph.types import interrupt
+from src.utils.log_sanitizer import (
+    sanitize_feedback,
+    sanitize_log_input,
+    sanitize_tool_name,
+)
 logger = logging.getLogger(__name__)
@@ -84,27 +90,30 @@ class ToolInterceptor:
 BaseTool: The wrapped tool with interrupt capability
 """
 original_func = tool.func
-logger.debug(f"Wrapping tool '{tool.name}' with interrupt capability")
+safe_tool_name = sanitize_tool_name(tool.name)
+logger.debug(f"Wrapping tool '{safe_tool_name}' with interrupt capability")
 def intercepted_func(*args: Any, **kwargs: Any) -> Any:
 """Execute the tool with interrupt check."""
 tool_name = tool.name
-logger.debug(f"[ToolInterceptor] Executing tool: {tool_name}")
+safe_tool_name_local = sanitize_tool_name(tool_name)
+logger.debug(f"[ToolInterceptor] Executing tool: {safe_tool_name_local}")
 # Format tool input for display
 tool_input = args[0] if args else kwargs
 tool_input_repr = ToolInterceptor._format_tool_input(tool_input)
-logger.debug(f"[ToolInterceptor] Tool input: {tool_input_repr[:200]}")
+safe_tool_input = sanitize_log_input(tool_input_repr, max_length=100)
+logger.debug(f"[ToolInterceptor] Tool input: {safe_tool_input}")
 should_interrupt = interceptor.should_interrupt(tool_name)
-logger.debug(f"[ToolInterceptor] should_interrupt={should_interrupt} for tool '{tool_name}'")
+logger.debug(f"[ToolInterceptor] should_interrupt={should_interrupt} for tool '{safe_tool_name_local}'")
 if should_interrupt:
 logger.info(
-f"[ToolInterceptor] Interrupting before tool '{tool_name}'"
+f"[ToolInterceptor] Interrupting before tool '{safe_tool_name_local}'"
 )
 logger.debug(
-f"[ToolInterceptor] Interrupt message: About to execute tool '{tool_name}' with input: {tool_input_repr[:100]}..."
+f"[ToolInterceptor] Interrupt message: About to execute tool '{safe_tool_name_local}' with input: {safe_tool_input}..."
 )
 # Trigger interrupt and wait for user feedback
@@ -112,41 +121,43 @@ class ToolInterceptor:
 feedback = interrupt(
 f"About to execute tool: '{tool_name}'\n\nInput:\n{tool_input_repr}\n\nApprove execution?"
 )
-logger.debug(f"[ToolInterceptor] Interrupt returned with feedback: {f'{feedback[:100]}...' if feedback and len(feedback) > 100 else feedback if feedback else 'None'}")
+safe_feedback = sanitize_feedback(feedback)
+logger.debug(f"[ToolInterceptor] Interrupt returned with feedback: {f'{safe_feedback[:100]}...' if safe_feedback and len(safe_feedback) > 100 else safe_feedback if safe_feedback else 'None'}")
 except Exception as e:
 logger.error(f"[ToolInterceptor] Error during interrupt: {str(e)}")
 raise
-logger.debug(f"[ToolInterceptor] Processing feedback approval for '{tool_name}'")
+logger.debug(f"[ToolInterceptor] Processing feedback approval for '{safe_tool_name_local}'")
 # Check if user approved
 is_approved = ToolInterceptor._parse_approval(feedback)
-logger.info(f"[ToolInterceptor] Tool '{tool_name}' approval decision: {is_approved}")
+logger.info(f"[ToolInterceptor] Tool '{safe_tool_name_local}' approval decision: {is_approved}")
 if not is_approved:
-logger.warning(f"[ToolInterceptor] User rejected execution of tool '{tool_name}'")
+logger.warning(f"[ToolInterceptor] User rejected execution of tool '{safe_tool_name_local}'")
 return {
 "error": f"Tool execution rejected by user",
 "tool": tool_name,
 "status": "rejected",
 }
-logger.info(f"[ToolInterceptor] User approved execution of tool '{tool_name}', proceeding")
+logger.info(f"[ToolInterceptor] User approved execution of tool '{safe_tool_name_local}', proceeding")
 # Execute the original tool
 try:
-logger.debug(f"[ToolInterceptor] Calling original function for tool '{tool_name}'")
+logger.debug(f"[ToolInterceptor] Calling original function for tool '{safe_tool_name_local}'")
 result = original_func(*args, **kwargs)
-logger.info(f"[ToolInterceptor] Tool '{tool_name}' execution completed successfully")
+logger.info(f"[ToolInterceptor] Tool '{safe_tool_name_local}' execution completed successfully")
-logger.debug(f"[ToolInterceptor] Tool result length: {len(str(result))}")
+result_len = len(str(result))
+logger.debug(f"[ToolInterceptor] Tool result length: {result_len}")
 return result
 except Exception as e:
-logger.error(f"[ToolInterceptor] Error executing tool '{tool_name}': {str(e)}")
+logger.error(f"[ToolInterceptor] Error executing tool '{safe_tool_name_local}': {str(e)}")
 raise
 # Replace the function and update the tool
 # Use object.__setattr__ to bypass Pydantic validation
-logger.debug(f"Attaching intercepted function to tool '{tool.name}'")
+logger.debug(f"Attaching intercepted function to tool '{safe_tool_name}'")
 object.__setattr__(tool, "func", intercepted_func)
 return tool


@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: MIT
 import logging
 from readabilipy import simple_json_from_html_string
 from .article import Article


@@ -55,6 +55,13 @@ from src.server.rag_request import (
 )
 from src.tools import VolcengineTTS
 from src.utils.json_utils import sanitize_args
+from src.utils.log_sanitizer import (
+    sanitize_agent_name,
+    sanitize_log_input,
+    sanitize_thread_id,
+    sanitize_tool_name,
+    sanitize_user_content,
+)
 logger = logging.getLogger(__name__)
@@ -333,9 +340,13 @@ def _process_initial_messages(message, thread_id):
 async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
 """Process a single message chunk and yield appropriate events."""
 agent_name = _get_agent_name(agent, message_metadata)
-logger.debug(f"[{thread_id}] _process_message_chunk started for agent_name={agent_name}")
-logger.debug(f"[{thread_id}] Extracted agent_name: {agent_name}")
+safe_agent_name = sanitize_agent_name(agent_name)
+safe_thread_id = sanitize_thread_id(thread_id)
+safe_agent = sanitize_agent_name(agent)
+logger.debug(f"[{safe_thread_id}] _process_message_chunk started for agent={safe_agent_name}")
+logger.debug(f"[{safe_thread_id}] Extracted agent_name: {safe_agent_name}")
 event_stream_message = _create_event_stream_message(
 message_chunk, message_metadata, thread_id, agent_name
@@ -343,25 +354,29 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
 if isinstance(message_chunk, ToolMessage):
 # Tool Message - Return the result of the tool call
-logger.debug(f"[{thread_id}] Processing ToolMessage")
+logger.debug(f"[{safe_thread_id}] Processing ToolMessage")
 tool_call_id = message_chunk.tool_call_id
 event_stream_message["tool_call_id"] = tool_call_id
 # Validate tool_call_id for debugging
 if tool_call_id:
-logger.debug(f"[{thread_id}] ToolMessage with tool_call_id: {tool_call_id}")
+safe_tool_id = sanitize_log_input(tool_call_id, max_length=100)
+logger.debug(f"[{safe_thread_id}] ToolMessage with tool_call_id: {safe_tool_id}")
 else:
-logger.warning(f"[{thread_id}] ToolMessage received without tool_call_id")
+logger.warning(f"[{safe_thread_id}] ToolMessage received without tool_call_id")
-logger.debug(f"[{thread_id}] Yielding tool_call_result event")
+logger.debug(f"[{safe_thread_id}] Yielding tool_call_result event")
 yield _make_event("tool_call_result", event_stream_message)
 elif isinstance(message_chunk, AIMessageChunk):
 # AI Message - Raw message tokens
-logger.debug(f"[{thread_id}] Processing AIMessageChunk, tool_calls={bool(message_chunk.tool_calls)}, tool_call_chunks={bool(message_chunk.tool_call_chunks)}")
+has_tool_calls = bool(message_chunk.tool_calls)
+has_chunks = bool(message_chunk.tool_call_chunks)
+logger.debug(f"[{safe_thread_id}] Processing AIMessageChunk, tool_calls={has_tool_calls}, tool_call_chunks={has_chunks}")
 if message_chunk.tool_calls:
 # AI Message - Tool Call (complete tool calls)
-logger.debug(f"[{thread_id}] AIMessageChunk has complete tool_calls: {[tc.get('name', 'unknown') for tc in message_chunk.tool_calls]}")
+safe_tool_names = [sanitize_tool_name(tc.get('name', 'unknown')) for tc in message_chunk.tool_calls]
+logger.debug(f"[{safe_thread_id}] AIMessageChunk has complete tool_calls: {safe_tool_names}")
 event_stream_message["tool_calls"] = message_chunk.tool_calls
 # Process tool_call_chunks with proper index-based grouping
@@ -370,16 +385,18 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
 )
 if processed_chunks:
 event_stream_message["tool_call_chunks"] = processed_chunks
+safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
 logger.debug(
-f"[{thread_id}] Tool calls: {[tc.get('name') for tc in message_chunk.tool_calls]}, "
+f"[{safe_thread_id}] Tool calls: {safe_tool_names}, "
 f"Processed chunks: {len(processed_chunks)}"
 )
-logger.debug(f"[{thread_id}] Yielding tool_calls event")
+logger.debug(f"[{safe_thread_id}] Yielding tool_calls event")
 yield _make_event("tool_calls", event_stream_message)
 elif message_chunk.tool_call_chunks:
 # AI Message - Tool Call Chunks (streaming)
-logger.debug(f"[{thread_id}] AIMessageChunk has streaming tool_call_chunks: {len(message_chunk.tool_call_chunks)} chunks")
+chunks_count = len(message_chunk.tool_call_chunks)
+logger.debug(f"[{safe_thread_id}] AIMessageChunk has streaming tool_call_chunks: {chunks_count} chunks")
 processed_chunks = _process_tool_call_chunks(
 message_chunk.tool_call_chunks
 )
@@ -392,26 +409,30 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
 # Log index transitions to detect tool call boundaries
 if prev_chunk is not None and current_index != prev_chunk.get("index"):
+prev_name = sanitize_tool_name(prev_chunk.get('name'))
+curr_name = sanitize_tool_name(chunk.get('name'))
 logger.debug(
-f"[{thread_id}] Tool call boundary detected: "
-f"index {prev_chunk.get('index')} ({prev_chunk.get('name')}) -> "
-f"{current_index} ({chunk.get('name')})"
+f"[{safe_thread_id}] Tool call boundary detected: "
+f"index {prev_chunk.get('index')} ({prev_name}) -> "
+f"{current_index} ({curr_name})"
 )
 prev_chunk = chunk
 # Include all processed chunks in the event
 event_stream_message["tool_call_chunks"] = processed_chunks
+safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
 logger.debug(
-f"[{thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): "
-f"{[c.get('name') for c in processed_chunks]}"
+f"[{safe_thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): "
+f"{safe_chunk_names}"
 )
-logger.debug(f"[{thread_id}] Yielding tool_call_chunks event")
+logger.debug(f"[{safe_thread_id}] Yielding tool_call_chunks event")
 yield _make_event("tool_call_chunks", event_stream_message)
 else:
 # AI Message - Raw message tokens
-logger.debug(f"[{thread_id}] AIMessageChunk is raw message tokens, content_len={len(message_chunk.content) if isinstance(message_chunk.content, str) else 'unknown'}")
+content_len = len(message_chunk.content) if isinstance(message_chunk.content, str) else 0
+logger.debug(f"[{safe_thread_id}] AIMessageChunk is raw message tokens, content_len={content_len}")
 yield _make_event("message_chunk", event_stream_message)
@@ -419,7 +440,8 @@ async def _stream_graph_events(
 graph_instance, workflow_input, workflow_config, thread_id
 ):
 """Stream events from the graph and process them."""
-logger.debug(f"[{thread_id}] Starting graph event stream with agent nodes")
+safe_thread_id = sanitize_thread_id(thread_id)
+logger.debug(f"[{safe_thread_id}] Starting graph event stream with agent nodes")
 try:
 event_count = 0
 async for agent, _, event_data in graph_instance.astream(
@@ -429,28 +451,31 @@ async def _stream_graph_events(
 subgraphs=True,
 ):
 event_count += 1
-logger.debug(f"[{thread_id}] Graph event #{event_count} received from agent: {agent}")
+safe_agent = sanitize_agent_name(agent)
+logger.debug(f"[{safe_thread_id}] Graph event #{event_count} received from agent: {safe_agent}")
 if isinstance(event_data, dict):
 if "__interrupt__" in event_data:
 logger.debug(
-f"[{thread_id}] Processing interrupt event: "
+f"[{safe_thread_id}] Processing interrupt event: "
 f"ns={getattr(event_data['__interrupt__'][0], 'ns', 'unknown') if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 else 'unknown'}, "
 f"value_len={len(getattr(event_data['__interrupt__'][0], 'value', '')) if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 and hasattr(event_data['__interrupt__'][0], 'value') and hasattr(event_data['__interrupt__'][0].value, '__len__') else 'unknown'}"
 )
 yield _create_interrupt_event(thread_id, event_data)
-logger.debug(f"[{thread_id}] Dict event without interrupt, skipping")
+logger.debug(f"[{safe_thread_id}] Dict event without interrupt, skipping")
 continue
 message_chunk, message_metadata = cast(
 tuple[BaseMessage, dict[str, Any]], event_data
 )
+safe_node = sanitize_agent_name(message_metadata.get('langgraph_node', 'unknown'))
+safe_step = sanitize_log_input(message_metadata.get('langgraph_step', 'unknown'))
 logger.debug(
-f"[{thread_id}] Processing message chunk: "
+f"[{safe_thread_id}] Processing message chunk: "
 f"type={type(message_chunk).__name__}, "
-f"node={message_metadata.get('langgraph_node', 'unknown')}, "
-f"step={message_metadata.get('langgraph_step', 'unknown')}"
+f"node={safe_node}, "
+f"step={safe_step}"
 )
 async for event in _process_message_chunk(
@@ -458,9 +483,9 @@ async def _stream_graph_events(
 ):
 yield event
-logger.debug(f"[{thread_id}] Graph event stream completed. Total events: {event_count}")
+logger.debug(f"[{safe_thread_id}] Graph event stream completed. Total events: {event_count}")
 except Exception as e:
-logger.exception(f"[{thread_id}] Error during graph execution")
+logger.exception(f"[{safe_thread_id}] Error during graph execution")
 yield _make_event(
 "error",
 {
@@ -488,34 +513,38 @@ async def _astream_workflow_generator(
 locale: str = "en-US",
 interrupt_before_tools: Optional[List[str]] = None,
 ):
+safe_thread_id = sanitize_thread_id(thread_id)
+safe_feedback = sanitize_log_input(interrupt_feedback) if interrupt_feedback else ""
 logger.debug(
-f"[{thread_id}] _astream_workflow_generator starting: "
+f"[{safe_thread_id}] _astream_workflow_generator starting: "
 f"messages_count={len(messages)}, "
 f"auto_accepted_plan={auto_accepted_plan}, "
-f"interrupt_feedback={interrupt_feedback}, "
+f"interrupt_feedback={safe_feedback}, "
 f"interrupt_before_tools={interrupt_before_tools}"
 )
 # Process initial messages
-logger.debug(f"[{thread_id}] Processing {len(messages)} initial messages")
+logger.debug(f"[{safe_thread_id}] Processing {len(messages)} initial messages")
 for message in messages:
 if isinstance(message, dict) and "content" in message:
-logger.debug(f"[{thread_id}] Sending initial message to client: {message.get('content', '')[:100]}")
+safe_content = sanitize_user_content(message.get('content', ''))
+logger.debug(f"[{safe_thread_id}] Sending initial message to client: {safe_content}")
 _process_initial_messages(message, thread_id)
-logger.debug(f"[{thread_id}] Reconstructing clarification history")
+logger.debug(f"[{safe_thread_id}] Reconstructing clarification history")
 clarification_history = reconstruct_clarification_history(messages)
-logger.debug(f"[{thread_id}] Building clarified topic from history")
+logger.debug(f"[{safe_thread_id}] Building clarified topic from history")
 clarified_topic, clarification_history = build_clarified_topic_from_history(
 clarification_history
 )
 latest_message_content = messages[-1]["content"] if messages else ""
 clarified_research_topic = clarified_topic or latest_message_content
-logger.debug(f"[{thread_id}] Clarified research topic: {clarified_research_topic[:100]}")
+safe_topic = sanitize_user_content(clarified_research_topic)
+logger.debug(f"[{safe_thread_id}] Clarified research topic: {safe_topic}")
 # Prepare workflow input
-logger.debug(f"[{thread_id}] Preparing workflow input")
+logger.debug(f"[{safe_thread_id}] Preparing workflow input")
 workflow_input = {
 "messages": messages,
 "plan_iterations": 0,
@@ -533,7 +562,7 @@ async def _astream_workflow_generator(
 }
 if not auto_accepted_plan and interrupt_feedback:
-logger.debug(f"[{thread_id}] Creating resume command with interrupt_feedback: {interrupt_feedback}")
+logger.debug(f"[{safe_thread_id}] Creating resume command with interrupt_feedback: {safe_feedback}")
 resume_msg = f"[{interrupt_feedback}]"
 if messages:
 resume_msg += f" {messages[-1]['content']}"
@@ -541,7 +570,7 @@ async def _astream_workflow_generator(
 # Prepare workflow config
 logger.debug(
-f"[{thread_id}] Preparing workflow config: "
+f"[{safe_thread_id}] Preparing workflow config: "
 f"max_plan_iterations={max_plan_iterations}, "
 f"max_step_num={max_step_num}, "
 f"report_style={report_style.value}, "
@@ -564,7 +593,7 @@ async def _astream_workflow_generator(
 checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
 logger.debug(
-f"[{thread_id}] Checkpoint configuration: "
+f"[{safe_thread_id}] Checkpoint configuration: "
 f"saver_enabled={checkpoint_saver}, "
 f"url_configured={bool(checkpoint_url)}"
 )
@@ -577,48 +606,48 @@ async def _astream_workflow_generator(
 }
 if checkpoint_saver and checkpoint_url != "":
 if checkpoint_url.startswith("postgresql://"):
-logger.info(f"[{thread_id}] Starting async postgres checkpointer")
-logger.debug(f"[{thread_id}] Setting up PostgreSQL connection pool")
+logger.info(f"[{safe_thread_id}] Starting async postgres checkpointer")
+logger.debug(f"[{safe_thread_id}] Setting up PostgreSQL connection pool")
 async with AsyncConnectionPool(
 checkpoint_url, kwargs=connection_kwargs
 ) as conn:
-logger.debug(f"[{thread_id}] Initializing AsyncPostgresSaver")
+logger.debug(f"[{safe_thread_id}] Initializing AsyncPostgresSaver")
 checkpointer = AsyncPostgresSaver(conn)
 await checkpointer.setup()
-logger.debug(f"[{thread_id}] Attaching checkpointer to graph")
+logger.debug(f"[{safe_thread_id}] Attaching checkpointer to graph")
 graph.checkpointer = checkpointer
 graph.store = in_memory_store
-logger.debug(f"[{thread_id}] Starting to stream graph events")
+logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
 async for event in _stream_graph_events(
 graph, workflow_input, workflow_config, thread_id
 ):
 yield event
-logger.debug(f"[{thread_id}] Graph event streaming completed")
+logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
 if checkpoint_url.startswith("mongodb://"):
-logger.info(f"[{thread_id}] Starting async mongodb checkpointer")
-logger.debug(f"[{thread_id}] Setting up MongoDB connection")
+logger.info(f"[{safe_thread_id}] Starting async mongodb checkpointer")
+logger.debug(f"[{safe_thread_id}] Setting up MongoDB connection")
 async with AsyncMongoDBSaver.from_conn_string(
 checkpoint_url
 ) as checkpointer:
-logger.debug(f"[{thread_id}] Attaching MongoDB checkpointer to graph")
+logger.debug(f"[{safe_thread_id}] Attaching MongoDB checkpointer to graph")
 graph.checkpointer = checkpointer
 graph.store = in_memory_store
-logger.debug(f"[{thread_id}] Starting to stream graph events")
+logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
 async for event in _stream_graph_events(
 graph, workflow_input, workflow_config, thread_id
 ):
 yield event
-logger.debug(f"[{thread_id}] Graph event streaming completed")
+logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
 else:
-logger.debug(f"[{thread_id}] No checkpointer configured, using in-memory graph")
+logger.debug(f"[{safe_thread_id}] No checkpointer configured, using in-memory graph")
 # Use graph without MongoDB checkpointer
-logger.debug(f"[{thread_id}] Starting to stream graph events")
+logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
 async for event in _stream_graph_events(
 graph, workflow_input, workflow_config, thread_id
 ):
 yield event
-logger.debug(f"[{thread_id}] Graph event streaming completed")
+logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
 def _make_event(event_type: str, data: dict[str, any]):

src/utils/log_sanitizer.py (new file, 186 lines)

@@ -0,0 +1,186 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

"""
Log sanitization utilities to prevent log injection attacks.

This module provides functions to sanitize user-controlled input before
logging to prevent attackers from forging log entries through:
- Newline injection (\n)
- HTML injection (for HTML logs)
- Special character sequences that could be misinterpreted
"""

import re
from typing import Any, Optional


def sanitize_log_input(value: Any, max_length: int = 500) -> str:
    """
    Sanitize user-controlled input for safe logging.

    Replaces dangerous characters (newlines, tabs, carriage returns, etc.)
    with their escaped representations to prevent log injection attacks.

    Args:
        value: The input value to sanitize (any type)
        max_length: Maximum length of output string (truncates if exceeded)

    Returns:
        str: Sanitized string safe for logging

    Examples:
        >>> sanitize_log_input("normal text")
        'normal text'
        >>> sanitize_log_input("malicious\n[INFO] fake entry")
        'malicious\\n[INFO] fake entry'
        >>> sanitize_log_input("tab\there")
        'tab\\there'
        >>> sanitize_log_input(None)
        'None'
        >>> long_text = "a" * 1000
        >>> result = sanitize_log_input(long_text, max_length=100)
        >>> len(result) <= 100
        True
    """
    if value is None:
        return "None"

    # Convert to string
    string_value = str(value)

    # Replace dangerous characters with their escaped representations
    # Order matters: escape backslashes first to avoid double-escaping
    replacements = {
        "\\": "\\\\",  # Backslash (must be first)
        "\n": "\\n",  # Newline - prevents creating new log entries
        "\r": "\\r",  # Carriage return
        "\t": "\\t",  # Tab
        "\x00": "\\0",  # Null character
        "\x1b": "\\x1b",  # Escape character (used in ANSI sequences)
    }

    for char, replacement in replacements.items():
        string_value = string_value.replace(char, replacement)

    # Remove other control characters (ASCII 0-31 except those already handled)
    # These are rarely useful in logs and could be exploited
    string_value = re.sub(r"[\x00-\x08\x0b-\x0c\x0e-\x1f]", "", string_value)

    # Truncate if too long (prevent log flooding)
    if len(string_value) > max_length:
        string_value = string_value[: max_length - 3] + "..."

    return string_value


def sanitize_thread_id(thread_id: Any) -> str:
    """
    Sanitize thread_id for logging.

    Thread IDs should be alphanumeric with hyphens and underscores,
    but we sanitize to be defensive.

    Args:
        thread_id: The thread ID to sanitize

    Returns:
        str: Sanitized thread ID
    """
    return sanitize_log_input(thread_id, max_length=100)


def sanitize_user_content(content: Any) -> str:
    """
    Sanitize user-provided message content for logging.

    User messages can be arbitrary length, so we truncate more aggressively.

    Args:
        content: The user content to sanitize

    Returns:
        str: Sanitized user content
    """
    return sanitize_log_input(content, max_length=200)


def sanitize_agent_name(agent_name: Any) -> str:
    """
    Sanitize agent name for logging.

    Agent names should be simple identifiers, but we sanitize to be defensive.

    Args:
        agent_name: The agent name to sanitize

    Returns:
        str: Sanitized agent name
    """
    return sanitize_log_input(agent_name, max_length=100)


def sanitize_tool_name(tool_name: Any) -> str:
    """
    Sanitize tool name for logging.

    Tool names should be simple identifiers, but we sanitize to be defensive.

    Args:
        tool_name: The tool name to sanitize

    Returns:
        str: Sanitized tool name
    """
    return sanitize_log_input(tool_name, max_length=100)


def sanitize_feedback(feedback: Any) -> str:
    """
    Sanitize user feedback for logging.

    Feedback can be arbitrary text from interrupts, so sanitize carefully.

    Args:
        feedback: The feedback to sanitize

    Returns:
        str: Sanitized feedback (truncated more aggressively)
    """
    return sanitize_log_input(feedback, max_length=150)


def create_safe_log_message(template: str, **kwargs) -> str:
    """
    Create a safe log message by sanitizing all values.

    Uses a template string with keyword arguments, sanitizing each value
    before substitution to prevent log injection.

    Args:
        template: Template string with {key} placeholders
        **kwargs: Key-value pairs to substitute

    Returns:
        str: Safe log message

    Example:
        >>> msg = create_safe_log_message(
        ...     "[{thread_id}] Processing {tool_name}",
        ...     thread_id="abc\\n[INFO]",
        ...     tool_name="my_tool"
        ... )
        >>> "[abc\\\\n[INFO]] Processing my_tool" in msg
        True
    """
    # Sanitize all values
    safe_kwargs = {
        key: sanitize_log_input(value) for key, value in kwargs.items()
    }

    # Substitute into template
    return template.format(**safe_kwargs)


@@ -11,12 +11,12 @@ Tests the complete flow of selective tool interrupts including:
 - Resume mechanism after interrupt
 """
-import pytest
-from unittest.mock import Mock, patch, AsyncMock, MagicMock, call
 from typing import Any
+from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
-from langchain_core.tools import tool
+import pytest
 from langchain_core.messages import HumanMessage
+from langchain_core.tools import tool
 from src.agents.agents import create_agent
 from src.agents.tool_interceptor import ToolInterceptor, wrap_tools_with_interceptor


@@ -1,9 +1,10 @@
 # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: MIT
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
 import pytest
-from unittest.mock import Mock, patch, MagicMock, AsyncMock
-from langchain_core.tools import tool, BaseTool
+from langchain_core.tools import BaseTool, tool
 from src.agents.tool_interceptor import (
 ToolInterceptor,


@@ -1,8 +1,10 @@
 # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: MIT
+from unittest.mock import Mock, patch
 import pytest
-from unittest.mock import patch, Mock
 from src.crawler.jina_client import JinaClient


@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: MIT
 from unittest.mock import patch
 from src.crawler.readability_extractor import ReadabilityExtractor


@@ -1,8 +1,9 @@
 # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: MIT
+from unittest.mock import MagicMock, patch
 import pytest
-from unittest.mock import patch, MagicMock
 from src.graph.nodes import validate_and_fix_plan


@@ -10,13 +10,14 @@ tool names from being concatenated when multiple tool calls happen in sequence.
""" """
import logging import logging
import pytest import os
from unittest.mock import patch, MagicMock
# Import the functions to test # Import the functions to test
# Note: We need to import from the app module # Note: We need to import from the app module
import sys import sys
import os from unittest.mock import MagicMock, patch
import pytest
# Add src directory to path for imports # Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../"))


@@ -3,7 +3,11 @@
 import json
-from src.utils.json_utils import repair_json_output, sanitize_tool_response, _extract_json_from_content
+from src.utils.json_utils import (
+    _extract_json_from_content,
+    repair_json_output,
+    sanitize_tool_response,
+)
 class TestRepairJsonOutput:


@@ -0,0 +1,268 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for log sanitization utilities.
This test file verifies that the log sanitizer properly prevents log injection attacks
by escaping dangerous characters in user-controlled input before logging.
"""
import pytest
from src.utils.log_sanitizer import (
create_safe_log_message,
sanitize_agent_name,
sanitize_feedback,
sanitize_log_input,
sanitize_thread_id,
sanitize_tool_name,
sanitize_user_content,
)
class TestSanitizeLogInput:
"""Test the main sanitize_log_input function."""
def test_sanitize_normal_text(self):
"""Test that normal text is preserved."""
text = "normal text"
result = sanitize_log_input(text)
assert result == "normal text"
def test_sanitize_newline_injection(self):
"""Test prevention of newline injection attack."""
malicious = "abc\n[INFO] Forged log entry"
result = sanitize_log_input(malicious)
assert "\n" not in result
assert "[INFO]" in result # The attack text is preserved but escaped
assert "\\n" in result # Newline is escaped
def test_sanitize_carriage_return(self):
"""Test prevention of carriage return injection."""
malicious = "text\r[WARN] Forged entry"
result = sanitize_log_input(malicious)
assert "\r" not in result
assert "\\r" in result
def test_sanitize_tab_character(self):
"""Test prevention of tab character injection."""
malicious = "text\t[ERROR] Forged"
result = sanitize_log_input(malicious)
assert "\t" not in result
assert "\\t" in result
def test_sanitize_null_character(self):
"""Test prevention of null character injection."""
malicious = "text\x00[CRITICAL]"
result = sanitize_log_input(malicious)
assert "\x00" not in result
def test_sanitize_backslash(self):
"""Test that backslashes are properly escaped."""
text = "path\\to\\file"
result = sanitize_log_input(text)
assert result == "path\\\\to\\\\file"
def test_sanitize_escape_character(self):
"""Test prevention of ANSI escape sequence injection."""
malicious = "text\x1b[31mRED TEXT\x1b[0m"
result = sanitize_log_input(malicious)
assert "\x1b" not in result
assert "\\x1b" in result
def test_sanitize_max_length_truncation(self):
"""Test that long strings are truncated."""
long_text = "a" * 1000
result = sanitize_log_input(long_text, max_length=100)
assert len(result) <= 100
assert result.endswith("...")
def test_sanitize_none_value(self):
"""Test that None is handled properly."""
result = sanitize_log_input(None)
assert result == "None"
def test_sanitize_numeric_value(self):
"""Test that numeric values are converted to strings."""
result = sanitize_log_input(12345)
assert result == "12345"
def test_sanitize_complex_injection_attack(self):
"""Test complex multi-character injection attack."""
malicious = 'thread-123\n[WARNING] Unauthorized\r[ERROR] System failure\t[CRITICAL] Shutdown'
result = sanitize_log_input(malicious)
# All dangerous characters should be escaped
assert "\n" not in result
assert "\r" not in result
assert "\t" not in result
# But the text should still be there (escaped)
assert "WARNING" in result
assert "ERROR" in result
class TestSanitizeThreadId:
"""Test sanitization of thread IDs."""
def test_thread_id_normal(self):
"""Test normal thread ID."""
thread_id = "thread-123-abc"
result = sanitize_thread_id(thread_id)
assert result == "thread-123-abc"
def test_thread_id_with_newline(self):
"""Test thread ID with newline injection."""
malicious = "thread-1\n[INFO] Forged"
result = sanitize_thread_id(malicious)
assert "\n" not in result
assert "\\n" in result
def test_thread_id_max_length(self):
"""Test that thread ID truncation respects max length."""
long_id = "x" * 200
result = sanitize_thread_id(long_id)
assert len(result) <= 100
class TestSanitizeUserContent:
"""Test sanitization of user-provided message content."""
def test_user_content_normal(self):
"""Test normal user content."""
content = "What is the weather today?"
result = sanitize_user_content(content)
assert result == "What is the weather today?"
def test_user_content_with_newline(self):
"""Test user content with newline."""
malicious = "My question\n[ADMIN] Delete user"
result = sanitize_user_content(malicious)
assert "\n" not in result
assert "\\n" in result
def test_user_content_max_length(self):
"""Test that user content is truncated more aggressively."""
long_content = "x" * 500
result = sanitize_user_content(long_content)
assert len(result) <= 200
class TestSanitizeToolName:
"""Test sanitization of tool names."""
def test_tool_name_normal(self):
"""Test normal tool name."""
tool = "web_search"
result = sanitize_tool_name(tool)
assert result == "web_search"
def test_tool_name_injection(self):
"""Test tool name with injection attempt."""
malicious = "search\n[WARN] Forged"
result = sanitize_tool_name(malicious)
assert "\n" not in result
class TestSanitizeFeedback:
"""Test sanitization of user feedback."""
def test_feedback_normal(self):
"""Test normal feedback."""
feedback = "[accepted]"
result = sanitize_feedback(feedback)
assert result == "[accepted]"
def test_feedback_injection(self):
"""Test feedback with injection attempt."""
malicious = "[approved]\n[CRITICAL] System down"
result = sanitize_feedback(malicious)
assert "\n" not in result
assert "\\n" in result
def test_feedback_max_length(self):
"""Test that feedback is truncated."""
long_feedback = "x" * 500
result = sanitize_feedback(long_feedback)
assert len(result) <= 150
class TestCreateSafeLogMessage:
"""Test the create_safe_log_message helper function."""
def test_safe_message_normal(self):
"""Test normal message creation."""
msg = create_safe_log_message(
"[{thread_id}] Processing {tool_name}",
thread_id="thread-1",
tool_name="search",
)
assert "[thread-1] Processing search" == msg
def test_safe_message_with_injection(self):
"""Test message creation with injected values."""
msg = create_safe_log_message(
"[{thread_id}] Tool: {tool_name}",
thread_id="id\n[INFO] Forged",
tool_name="search\r[ERROR]",
)
# The dangerous characters should be escaped
assert "\n" not in msg
assert "\r" not in msg
assert "\\n" in msg
assert "\\r" in msg
def test_safe_message_multiple_values(self):
"""Test message with multiple values."""
msg = create_safe_log_message(
"[{id}] User: {user} Tool: {tool}",
id="123",
user="admin\t[WARN]",
tool="delete\x1b[31m",
)
assert "\t" not in msg
assert "\x1b" not in msg
class TestLogInjectionAttackPrevention:
"""Integration tests for log injection prevention."""
def test_classic_log_injection_newline(self):
"""Test the classic log injection attack using newlines."""
attacker_input = 'abc\n[WARNING] Unauthorized access detected'
result = sanitize_log_input(attacker_input)
# The output should not contain an actual newline that would create a new log entry
assert result.count("\n") == 0
# But the escaped version should be in there
assert "\\n" in result
def test_carriage_return_log_injection(self):
"""Test log injection via carriage return."""
attacker_input = "request_id\r\n[ERROR] CRITICAL FAILURE"
result = sanitize_log_input(attacker_input)
assert "\r" not in result
assert "\n" not in result
def test_html_injection_prevention(self):
"""Test prevention of HTML injection in logs."""
# While HTML tags themselves aren't dangerous in log files,
# escaping control characters helps prevent parsing attacks
malicious_html = "user\x1b[32m<script>alert('xss')</script>"
result = sanitize_log_input(malicious_html)
assert "\x1b" not in result
# HTML is preserved but with escaped control chars
assert "<script>" in result
def test_multiple_injection_techniques(self):
"""Test prevention of multiple injection techniques combined."""
attack = 'id_1\n\r\t[CRITICAL]\x1b[31m RED TEXT'
result = sanitize_log_input(attack)
# No actual control characters should exist
assert "\n" not in result
assert "\r" not in result
assert "\t" not in result
assert "\x1b" not in result
# But escaped versions should exist
assert "\\n" in result
assert "\\r" in result
assert "\\t" in result
assert "\\x1b" in result