mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-27 07:44:48 +08:00
security: add log injection attack prevention with input sanitization (#667)
* security: add log injection attack prevention with input sanitization - Created src/utils/log_sanitizer.py to sanitize user-controlled input before logging - Prevents log injection attacks using newlines, tabs, carriage returns, etc. - Escapes dangerous characters: \n, \r, \t, \0, \x1b - Provides specialized functions for different input types: - sanitize_log_input: general purpose sanitization - sanitize_thread_id: for user-provided thread IDs - sanitize_user_content: for user messages (more aggressive truncation) - sanitize_agent_name: for agent identifiers - sanitize_tool_name: for tool names - sanitize_feedback: for user interrupt feedback - create_safe_log_message: template-based safe message creation - Updated src/server/app.py to sanitize all user input in logging: - Thread IDs from request parameter - Message content from user - Agent names and node information - Tool names and feedback - Updated src/agents/tool_interceptor.py to sanitize: - Tool names during execution - User feedback during interrupt handling - Tool input data - Added 29 comprehensive unit tests covering: - Classic newline injection attacks - Carriage return injection - Tab and null character injection - HTML/ANSI escape sequence injection - Combined multi-character attacks - Truncation and length limits Fixes potential log forgery vulnerability where malicious users could inject fake log entries via unsanitized input containing control characters.
This commit is contained in:
@@ -6,10 +6,10 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
|
||||||
|
from src.agents.tool_interceptor import wrap_tools_with_interceptor
|
||||||
from src.config.agents import AGENT_LLM_MAP
|
from src.config.agents import AGENT_LLM_MAP
|
||||||
from src.llms.llm import get_llm_by_type
|
from src.llms.llm import get_llm_by_type
|
||||||
from src.prompts import apply_prompt_template
|
from src.prompts import apply_prompt_template
|
||||||
from src.agents.tool_interceptor import wrap_tools_with_interceptor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,12 @@ from typing import Any, Callable, List, Optional
|
|||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.types import interrupt
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
from src.utils.log_sanitizer import (
|
||||||
|
sanitize_feedback,
|
||||||
|
sanitize_log_input,
|
||||||
|
sanitize_tool_name,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,27 +90,30 @@ class ToolInterceptor:
|
|||||||
BaseTool: The wrapped tool with interrupt capability
|
BaseTool: The wrapped tool with interrupt capability
|
||||||
"""
|
"""
|
||||||
original_func = tool.func
|
original_func = tool.func
|
||||||
logger.debug(f"Wrapping tool '{tool.name}' with interrupt capability")
|
safe_tool_name = sanitize_tool_name(tool.name)
|
||||||
|
logger.debug(f"Wrapping tool '{safe_tool_name}' with interrupt capability")
|
||||||
|
|
||||||
def intercepted_func(*args: Any, **kwargs: Any) -> Any:
|
def intercepted_func(*args: Any, **kwargs: Any) -> Any:
|
||||||
"""Execute the tool with interrupt check."""
|
"""Execute the tool with interrupt check."""
|
||||||
tool_name = tool.name
|
tool_name = tool.name
|
||||||
logger.debug(f"[ToolInterceptor] Executing tool: {tool_name}")
|
safe_tool_name_local = sanitize_tool_name(tool_name)
|
||||||
|
logger.debug(f"[ToolInterceptor] Executing tool: {safe_tool_name_local}")
|
||||||
|
|
||||||
# Format tool input for display
|
# Format tool input for display
|
||||||
tool_input = args[0] if args else kwargs
|
tool_input = args[0] if args else kwargs
|
||||||
tool_input_repr = ToolInterceptor._format_tool_input(tool_input)
|
tool_input_repr = ToolInterceptor._format_tool_input(tool_input)
|
||||||
logger.debug(f"[ToolInterceptor] Tool input: {tool_input_repr[:200]}")
|
safe_tool_input = sanitize_log_input(tool_input_repr, max_length=100)
|
||||||
|
logger.debug(f"[ToolInterceptor] Tool input: {safe_tool_input}")
|
||||||
|
|
||||||
should_interrupt = interceptor.should_interrupt(tool_name)
|
should_interrupt = interceptor.should_interrupt(tool_name)
|
||||||
logger.debug(f"[ToolInterceptor] should_interrupt={should_interrupt} for tool '{tool_name}'")
|
logger.debug(f"[ToolInterceptor] should_interrupt={should_interrupt} for tool '{safe_tool_name_local}'")
|
||||||
|
|
||||||
if should_interrupt:
|
if should_interrupt:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[ToolInterceptor] Interrupting before tool '{tool_name}'"
|
f"[ToolInterceptor] Interrupting before tool '{safe_tool_name_local}'"
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[ToolInterceptor] Interrupt message: About to execute tool '{tool_name}' with input: {tool_input_repr[:100]}..."
|
f"[ToolInterceptor] Interrupt message: About to execute tool '{safe_tool_name_local}' with input: {safe_tool_input}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trigger interrupt and wait for user feedback
|
# Trigger interrupt and wait for user feedback
|
||||||
@@ -112,41 +121,43 @@ class ToolInterceptor:
|
|||||||
feedback = interrupt(
|
feedback = interrupt(
|
||||||
f"About to execute tool: '{tool_name}'\n\nInput:\n{tool_input_repr}\n\nApprove execution?"
|
f"About to execute tool: '{tool_name}'\n\nInput:\n{tool_input_repr}\n\nApprove execution?"
|
||||||
)
|
)
|
||||||
logger.debug(f"[ToolInterceptor] Interrupt returned with feedback: {f'{feedback[:100]}...' if feedback and len(feedback) > 100 else feedback if feedback else 'None'}")
|
safe_feedback = sanitize_feedback(feedback)
|
||||||
|
logger.debug(f"[ToolInterceptor] Interrupt returned with feedback: {f'{safe_feedback[:100]}...' if safe_feedback and len(safe_feedback) > 100 else safe_feedback if safe_feedback else 'None'}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[ToolInterceptor] Error during interrupt: {str(e)}")
|
logger.error(f"[ToolInterceptor] Error during interrupt: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
logger.debug(f"[ToolInterceptor] Processing feedback approval for '{tool_name}'")
|
logger.debug(f"[ToolInterceptor] Processing feedback approval for '{safe_tool_name_local}'")
|
||||||
|
|
||||||
# Check if user approved
|
# Check if user approved
|
||||||
is_approved = ToolInterceptor._parse_approval(feedback)
|
is_approved = ToolInterceptor._parse_approval(feedback)
|
||||||
logger.info(f"[ToolInterceptor] Tool '{tool_name}' approval decision: {is_approved}")
|
logger.info(f"[ToolInterceptor] Tool '{safe_tool_name_local}' approval decision: {is_approved}")
|
||||||
|
|
||||||
if not is_approved:
|
if not is_approved:
|
||||||
logger.warning(f"[ToolInterceptor] User rejected execution of tool '{tool_name}'")
|
logger.warning(f"[ToolInterceptor] User rejected execution of tool '{safe_tool_name_local}'")
|
||||||
return {
|
return {
|
||||||
"error": f"Tool execution rejected by user",
|
"error": f"Tool execution rejected by user",
|
||||||
"tool": tool_name,
|
"tool": tool_name,
|
||||||
"status": "rejected",
|
"status": "rejected",
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"[ToolInterceptor] User approved execution of tool '{tool_name}', proceeding")
|
logger.info(f"[ToolInterceptor] User approved execution of tool '{safe_tool_name_local}', proceeding")
|
||||||
|
|
||||||
# Execute the original tool
|
# Execute the original tool
|
||||||
try:
|
try:
|
||||||
logger.debug(f"[ToolInterceptor] Calling original function for tool '{tool_name}'")
|
logger.debug(f"[ToolInterceptor] Calling original function for tool '{safe_tool_name_local}'")
|
||||||
result = original_func(*args, **kwargs)
|
result = original_func(*args, **kwargs)
|
||||||
logger.info(f"[ToolInterceptor] Tool '{tool_name}' execution completed successfully")
|
logger.info(f"[ToolInterceptor] Tool '{safe_tool_name_local}' execution completed successfully")
|
||||||
logger.debug(f"[ToolInterceptor] Tool result length: {len(str(result))}")
|
result_len = len(str(result))
|
||||||
|
logger.debug(f"[ToolInterceptor] Tool result length: {result_len}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[ToolInterceptor] Error executing tool '{tool_name}': {str(e)}")
|
logger.error(f"[ToolInterceptor] Error executing tool '{safe_tool_name_local}': {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Replace the function and update the tool
|
# Replace the function and update the tool
|
||||||
# Use object.__setattr__ to bypass Pydantic validation
|
# Use object.__setattr__ to bypass Pydantic validation
|
||||||
logger.debug(f"Attaching intercepted function to tool '{tool.name}'")
|
logger.debug(f"Attaching intercepted function to tool '{safe_tool_name}'")
|
||||||
object.__setattr__(tool, "func", intercepted_func)
|
object.__setattr__(tool, "func", intercepted_func)
|
||||||
return tool
|
return tool
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from readabilipy import simple_json_from_html_string
|
from readabilipy import simple_json_from_html_string
|
||||||
|
|
||||||
from .article import Article
|
from .article import Article
|
||||||
|
|||||||
@@ -55,6 +55,13 @@ from src.server.rag_request import (
|
|||||||
)
|
)
|
||||||
from src.tools import VolcengineTTS
|
from src.tools import VolcengineTTS
|
||||||
from src.utils.json_utils import sanitize_args
|
from src.utils.json_utils import sanitize_args
|
||||||
|
from src.utils.log_sanitizer import (
|
||||||
|
sanitize_agent_name,
|
||||||
|
sanitize_log_input,
|
||||||
|
sanitize_thread_id,
|
||||||
|
sanitize_tool_name,
|
||||||
|
sanitize_user_content,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -333,9 +340,13 @@ def _process_initial_messages(message, thread_id):
|
|||||||
|
|
||||||
async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
|
async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
|
||||||
"""Process a single message chunk and yield appropriate events."""
|
"""Process a single message chunk and yield appropriate events."""
|
||||||
|
|
||||||
agent_name = _get_agent_name(agent, message_metadata)
|
agent_name = _get_agent_name(agent, message_metadata)
|
||||||
logger.debug(f"[{thread_id}] _process_message_chunk started for agent_name={agent_name}")
|
safe_agent_name = sanitize_agent_name(agent_name)
|
||||||
logger.debug(f"[{thread_id}] Extracted agent_name: {agent_name}")
|
safe_thread_id = sanitize_thread_id(thread_id)
|
||||||
|
safe_agent = sanitize_agent_name(agent)
|
||||||
|
logger.debug(f"[{safe_thread_id}] _process_message_chunk started for agent={safe_agent_name}")
|
||||||
|
logger.debug(f"[{safe_thread_id}] Extracted agent_name: {safe_agent_name}")
|
||||||
|
|
||||||
event_stream_message = _create_event_stream_message(
|
event_stream_message = _create_event_stream_message(
|
||||||
message_chunk, message_metadata, thread_id, agent_name
|
message_chunk, message_metadata, thread_id, agent_name
|
||||||
@@ -343,25 +354,29 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
|
|||||||
|
|
||||||
if isinstance(message_chunk, ToolMessage):
|
if isinstance(message_chunk, ToolMessage):
|
||||||
# Tool Message - Return the result of the tool call
|
# Tool Message - Return the result of the tool call
|
||||||
logger.debug(f"[{thread_id}] Processing ToolMessage")
|
logger.debug(f"[{safe_thread_id}] Processing ToolMessage")
|
||||||
tool_call_id = message_chunk.tool_call_id
|
tool_call_id = message_chunk.tool_call_id
|
||||||
event_stream_message["tool_call_id"] = tool_call_id
|
event_stream_message["tool_call_id"] = tool_call_id
|
||||||
|
|
||||||
# Validate tool_call_id for debugging
|
# Validate tool_call_id for debugging
|
||||||
if tool_call_id:
|
if tool_call_id:
|
||||||
logger.debug(f"[{thread_id}] ToolMessage with tool_call_id: {tool_call_id}")
|
safe_tool_id = sanitize_log_input(tool_call_id, max_length=100)
|
||||||
|
logger.debug(f"[{safe_thread_id}] ToolMessage with tool_call_id: {safe_tool_id}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{thread_id}] ToolMessage received without tool_call_id")
|
logger.warning(f"[{safe_thread_id}] ToolMessage received without tool_call_id")
|
||||||
|
|
||||||
logger.debug(f"[{thread_id}] Yielding tool_call_result event")
|
logger.debug(f"[{safe_thread_id}] Yielding tool_call_result event")
|
||||||
yield _make_event("tool_call_result", event_stream_message)
|
yield _make_event("tool_call_result", event_stream_message)
|
||||||
elif isinstance(message_chunk, AIMessageChunk):
|
elif isinstance(message_chunk, AIMessageChunk):
|
||||||
# AI Message - Raw message tokens
|
# AI Message - Raw message tokens
|
||||||
logger.debug(f"[{thread_id}] Processing AIMessageChunk, tool_calls={bool(message_chunk.tool_calls)}, tool_call_chunks={bool(message_chunk.tool_call_chunks)}")
|
has_tool_calls = bool(message_chunk.tool_calls)
|
||||||
|
has_chunks = bool(message_chunk.tool_call_chunks)
|
||||||
|
logger.debug(f"[{safe_thread_id}] Processing AIMessageChunk, tool_calls={has_tool_calls}, tool_call_chunks={has_chunks}")
|
||||||
|
|
||||||
if message_chunk.tool_calls:
|
if message_chunk.tool_calls:
|
||||||
# AI Message - Tool Call (complete tool calls)
|
# AI Message - Tool Call (complete tool calls)
|
||||||
logger.debug(f"[{thread_id}] AIMessageChunk has complete tool_calls: {[tc.get('name', 'unknown') for tc in message_chunk.tool_calls]}")
|
safe_tool_names = [sanitize_tool_name(tc.get('name', 'unknown')) for tc in message_chunk.tool_calls]
|
||||||
|
logger.debug(f"[{safe_thread_id}] AIMessageChunk has complete tool_calls: {safe_tool_names}")
|
||||||
event_stream_message["tool_calls"] = message_chunk.tool_calls
|
event_stream_message["tool_calls"] = message_chunk.tool_calls
|
||||||
|
|
||||||
# Process tool_call_chunks with proper index-based grouping
|
# Process tool_call_chunks with proper index-based grouping
|
||||||
@@ -370,16 +385,18 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
|
|||||||
)
|
)
|
||||||
if processed_chunks:
|
if processed_chunks:
|
||||||
event_stream_message["tool_call_chunks"] = processed_chunks
|
event_stream_message["tool_call_chunks"] = processed_chunks
|
||||||
|
safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{thread_id}] Tool calls: {[tc.get('name') for tc in message_chunk.tool_calls]}, "
|
f"[{safe_thread_id}] Tool calls: {safe_tool_names}, "
|
||||||
f"Processed chunks: {len(processed_chunks)}"
|
f"Processed chunks: {len(processed_chunks)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"[{thread_id}] Yielding tool_calls event")
|
logger.debug(f"[{safe_thread_id}] Yielding tool_calls event")
|
||||||
yield _make_event("tool_calls", event_stream_message)
|
yield _make_event("tool_calls", event_stream_message)
|
||||||
elif message_chunk.tool_call_chunks:
|
elif message_chunk.tool_call_chunks:
|
||||||
# AI Message - Tool Call Chunks (streaming)
|
# AI Message - Tool Call Chunks (streaming)
|
||||||
logger.debug(f"[{thread_id}] AIMessageChunk has streaming tool_call_chunks: {len(message_chunk.tool_call_chunks)} chunks")
|
chunks_count = len(message_chunk.tool_call_chunks)
|
||||||
|
logger.debug(f"[{safe_thread_id}] AIMessageChunk has streaming tool_call_chunks: {chunks_count} chunks")
|
||||||
processed_chunks = _process_tool_call_chunks(
|
processed_chunks = _process_tool_call_chunks(
|
||||||
message_chunk.tool_call_chunks
|
message_chunk.tool_call_chunks
|
||||||
)
|
)
|
||||||
@@ -392,26 +409,30 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
|
|||||||
|
|
||||||
# Log index transitions to detect tool call boundaries
|
# Log index transitions to detect tool call boundaries
|
||||||
if prev_chunk is not None and current_index != prev_chunk.get("index"):
|
if prev_chunk is not None and current_index != prev_chunk.get("index"):
|
||||||
|
prev_name = sanitize_tool_name(prev_chunk.get('name'))
|
||||||
|
curr_name = sanitize_tool_name(chunk.get('name'))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{thread_id}] Tool call boundary detected: "
|
f"[{safe_thread_id}] Tool call boundary detected: "
|
||||||
f"index {prev_chunk.get('index')} ({prev_chunk.get('name')}) -> "
|
f"index {prev_chunk.get('index')} ({prev_name}) -> "
|
||||||
f"{current_index} ({chunk.get('name')})"
|
f"{current_index} ({curr_name})"
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_chunk = chunk
|
prev_chunk = chunk
|
||||||
|
|
||||||
# Include all processed chunks in the event
|
# Include all processed chunks in the event
|
||||||
event_stream_message["tool_call_chunks"] = processed_chunks
|
event_stream_message["tool_call_chunks"] = processed_chunks
|
||||||
|
safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): "
|
f"[{safe_thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): "
|
||||||
f"{[c.get('name') for c in processed_chunks]}"
|
f"{safe_chunk_names}"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"[{thread_id}] Yielding tool_call_chunks event")
|
logger.debug(f"[{safe_thread_id}] Yielding tool_call_chunks event")
|
||||||
yield _make_event("tool_call_chunks", event_stream_message)
|
yield _make_event("tool_call_chunks", event_stream_message)
|
||||||
else:
|
else:
|
||||||
# AI Message - Raw message tokens
|
# AI Message - Raw message tokens
|
||||||
logger.debug(f"[{thread_id}] AIMessageChunk is raw message tokens, content_len={len(message_chunk.content) if isinstance(message_chunk.content, str) else 'unknown'}")
|
content_len = len(message_chunk.content) if isinstance(message_chunk.content, str) else 0
|
||||||
|
logger.debug(f"[{safe_thread_id}] AIMessageChunk is raw message tokens, content_len={content_len}")
|
||||||
yield _make_event("message_chunk", event_stream_message)
|
yield _make_event("message_chunk", event_stream_message)
|
||||||
|
|
||||||
|
|
||||||
@@ -419,7 +440,8 @@ async def _stream_graph_events(
|
|||||||
graph_instance, workflow_input, workflow_config, thread_id
|
graph_instance, workflow_input, workflow_config, thread_id
|
||||||
):
|
):
|
||||||
"""Stream events from the graph and process them."""
|
"""Stream events from the graph and process them."""
|
||||||
logger.debug(f"[{thread_id}] Starting graph event stream with agent nodes")
|
safe_thread_id = sanitize_thread_id(thread_id)
|
||||||
|
logger.debug(f"[{safe_thread_id}] Starting graph event stream with agent nodes")
|
||||||
try:
|
try:
|
||||||
event_count = 0
|
event_count = 0
|
||||||
async for agent, _, event_data in graph_instance.astream(
|
async for agent, _, event_data in graph_instance.astream(
|
||||||
@@ -429,28 +451,31 @@ async def _stream_graph_events(
|
|||||||
subgraphs=True,
|
subgraphs=True,
|
||||||
):
|
):
|
||||||
event_count += 1
|
event_count += 1
|
||||||
logger.debug(f"[{thread_id}] Graph event #{event_count} received from agent: {agent}")
|
safe_agent = sanitize_agent_name(agent)
|
||||||
|
logger.debug(f"[{safe_thread_id}] Graph event #{event_count} received from agent: {safe_agent}")
|
||||||
|
|
||||||
if isinstance(event_data, dict):
|
if isinstance(event_data, dict):
|
||||||
if "__interrupt__" in event_data:
|
if "__interrupt__" in event_data:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{thread_id}] Processing interrupt event: "
|
f"[{safe_thread_id}] Processing interrupt event: "
|
||||||
f"ns={getattr(event_data['__interrupt__'][0], 'ns', 'unknown') if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 else 'unknown'}, "
|
f"ns={getattr(event_data['__interrupt__'][0], 'ns', 'unknown') if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 else 'unknown'}, "
|
||||||
f"value_len={len(getattr(event_data['__interrupt__'][0], 'value', '')) if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 and hasattr(event_data['__interrupt__'][0], 'value') and hasattr(event_data['__interrupt__'][0].value, '__len__') else 'unknown'}"
|
f"value_len={len(getattr(event_data['__interrupt__'][0], 'value', '')) if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 and hasattr(event_data['__interrupt__'][0], 'value') and hasattr(event_data['__interrupt__'][0].value, '__len__') else 'unknown'}"
|
||||||
)
|
)
|
||||||
yield _create_interrupt_event(thread_id, event_data)
|
yield _create_interrupt_event(thread_id, event_data)
|
||||||
logger.debug(f"[{thread_id}] Dict event without interrupt, skipping")
|
logger.debug(f"[{safe_thread_id}] Dict event without interrupt, skipping")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
message_chunk, message_metadata = cast(
|
message_chunk, message_metadata = cast(
|
||||||
tuple[BaseMessage, dict[str, Any]], event_data
|
tuple[BaseMessage, dict[str, Any]], event_data
|
||||||
)
|
)
|
||||||
|
|
||||||
|
safe_node = sanitize_agent_name(message_metadata.get('langgraph_node', 'unknown'))
|
||||||
|
safe_step = sanitize_log_input(message_metadata.get('langgraph_step', 'unknown'))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{thread_id}] Processing message chunk: "
|
f"[{safe_thread_id}] Processing message chunk: "
|
||||||
f"type={type(message_chunk).__name__}, "
|
f"type={type(message_chunk).__name__}, "
|
||||||
f"node={message_metadata.get('langgraph_node', 'unknown')}, "
|
f"node={safe_node}, "
|
||||||
f"step={message_metadata.get('langgraph_step', 'unknown')}"
|
f"step={safe_step}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async for event in _process_message_chunk(
|
async for event in _process_message_chunk(
|
||||||
@@ -458,9 +483,9 @@ async def _stream_graph_events(
|
|||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
logger.debug(f"[{thread_id}] Graph event stream completed. Total events: {event_count}")
|
logger.debug(f"[{safe_thread_id}] Graph event stream completed. Total events: {event_count}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"[{thread_id}] Error during graph execution")
|
logger.exception(f"[{safe_thread_id}] Error during graph execution")
|
||||||
yield _make_event(
|
yield _make_event(
|
||||||
"error",
|
"error",
|
||||||
{
|
{
|
||||||
@@ -488,34 +513,38 @@ async def _astream_workflow_generator(
|
|||||||
locale: str = "en-US",
|
locale: str = "en-US",
|
||||||
interrupt_before_tools: Optional[List[str]] = None,
|
interrupt_before_tools: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
|
safe_thread_id = sanitize_thread_id(thread_id)
|
||||||
|
safe_feedback = sanitize_log_input(interrupt_feedback) if interrupt_feedback else ""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{thread_id}] _astream_workflow_generator starting: "
|
f"[{safe_thread_id}] _astream_workflow_generator starting: "
|
||||||
f"messages_count={len(messages)}, "
|
f"messages_count={len(messages)}, "
|
||||||
f"auto_accepted_plan={auto_accepted_plan}, "
|
f"auto_accepted_plan={auto_accepted_plan}, "
|
||||||
f"interrupt_feedback={interrupt_feedback}, "
|
f"interrupt_feedback={safe_feedback}, "
|
||||||
f"interrupt_before_tools={interrupt_before_tools}"
|
f"interrupt_before_tools={interrupt_before_tools}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process initial messages
|
# Process initial messages
|
||||||
logger.debug(f"[{thread_id}] Processing {len(messages)} initial messages")
|
logger.debug(f"[{safe_thread_id}] Processing {len(messages)} initial messages")
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if isinstance(message, dict) and "content" in message:
|
if isinstance(message, dict) and "content" in message:
|
||||||
logger.debug(f"[{thread_id}] Sending initial message to client: {message.get('content', '')[:100]}")
|
safe_content = sanitize_user_content(message.get('content', ''))
|
||||||
|
logger.debug(f"[{safe_thread_id}] Sending initial message to client: {safe_content}")
|
||||||
_process_initial_messages(message, thread_id)
|
_process_initial_messages(message, thread_id)
|
||||||
|
|
||||||
logger.debug(f"[{thread_id}] Reconstructing clarification history")
|
logger.debug(f"[{safe_thread_id}] Reconstructing clarification history")
|
||||||
clarification_history = reconstruct_clarification_history(messages)
|
clarification_history = reconstruct_clarification_history(messages)
|
||||||
|
|
||||||
logger.debug(f"[{thread_id}] Building clarified topic from history")
|
logger.debug(f"[{safe_thread_id}] Building clarified topic from history")
|
||||||
clarified_topic, clarification_history = build_clarified_topic_from_history(
|
clarified_topic, clarification_history = build_clarified_topic_from_history(
|
||||||
clarification_history
|
clarification_history
|
||||||
)
|
)
|
||||||
latest_message_content = messages[-1]["content"] if messages else ""
|
latest_message_content = messages[-1]["content"] if messages else ""
|
||||||
clarified_research_topic = clarified_topic or latest_message_content
|
clarified_research_topic = clarified_topic or latest_message_content
|
||||||
logger.debug(f"[{thread_id}] Clarified research topic: {clarified_research_topic[:100]}")
|
safe_topic = sanitize_user_content(clarified_research_topic)
|
||||||
|
logger.debug(f"[{safe_thread_id}] Clarified research topic: {safe_topic}")
|
||||||
|
|
||||||
# Prepare workflow input
|
# Prepare workflow input
|
||||||
logger.debug(f"[{thread_id}] Preparing workflow input")
|
logger.debug(f"[{safe_thread_id}] Preparing workflow input")
|
||||||
workflow_input = {
|
workflow_input = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"plan_iterations": 0,
|
"plan_iterations": 0,
|
||||||
@@ -533,7 +562,7 @@ async def _astream_workflow_generator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if not auto_accepted_plan and interrupt_feedback:
|
if not auto_accepted_plan and interrupt_feedback:
|
||||||
logger.debug(f"[{thread_id}] Creating resume command with interrupt_feedback: {interrupt_feedback}")
|
logger.debug(f"[{safe_thread_id}] Creating resume command with interrupt_feedback: {safe_feedback}")
|
||||||
resume_msg = f"[{interrupt_feedback}]"
|
resume_msg = f"[{interrupt_feedback}]"
|
||||||
if messages:
|
if messages:
|
||||||
resume_msg += f" {messages[-1]['content']}"
|
resume_msg += f" {messages[-1]['content']}"
|
||||||
@@ -541,7 +570,7 @@ async def _astream_workflow_generator(
|
|||||||
|
|
||||||
# Prepare workflow config
|
# Prepare workflow config
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{thread_id}] Preparing workflow config: "
|
f"[{safe_thread_id}] Preparing workflow config: "
|
||||||
f"max_plan_iterations={max_plan_iterations}, "
|
f"max_plan_iterations={max_plan_iterations}, "
|
||||||
f"max_step_num={max_step_num}, "
|
f"max_step_num={max_step_num}, "
|
||||||
f"report_style={report_style.value}, "
|
f"report_style={report_style.value}, "
|
||||||
@@ -564,7 +593,7 @@ async def _astream_workflow_generator(
|
|||||||
checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
|
checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{thread_id}] Checkpoint configuration: "
|
f"[{safe_thread_id}] Checkpoint configuration: "
|
||||||
f"saver_enabled={checkpoint_saver}, "
|
f"saver_enabled={checkpoint_saver}, "
|
||||||
f"url_configured={bool(checkpoint_url)}"
|
f"url_configured={bool(checkpoint_url)}"
|
||||||
)
|
)
|
||||||
@@ -577,48 +606,48 @@ async def _astream_workflow_generator(
|
|||||||
}
|
}
|
||||||
if checkpoint_saver and checkpoint_url != "":
|
if checkpoint_saver and checkpoint_url != "":
|
||||||
if checkpoint_url.startswith("postgresql://"):
|
if checkpoint_url.startswith("postgresql://"):
|
||||||
logger.info(f"[{thread_id}] Starting async postgres checkpointer")
|
logger.info(f"[{safe_thread_id}] Starting async postgres checkpointer")
|
||||||
logger.debug(f"[{thread_id}] Setting up PostgreSQL connection pool")
|
logger.debug(f"[{safe_thread_id}] Setting up PostgreSQL connection pool")
|
||||||
async with AsyncConnectionPool(
|
async with AsyncConnectionPool(
|
||||||
checkpoint_url, kwargs=connection_kwargs
|
checkpoint_url, kwargs=connection_kwargs
|
||||||
) as conn:
|
) as conn:
|
||||||
logger.debug(f"[{thread_id}] Initializing AsyncPostgresSaver")
|
logger.debug(f"[{safe_thread_id}] Initializing AsyncPostgresSaver")
|
||||||
checkpointer = AsyncPostgresSaver(conn)
|
checkpointer = AsyncPostgresSaver(conn)
|
||||||
await checkpointer.setup()
|
await checkpointer.setup()
|
||||||
logger.debug(f"[{thread_id}] Attaching checkpointer to graph")
|
logger.debug(f"[{safe_thread_id}] Attaching checkpointer to graph")
|
||||||
graph.checkpointer = checkpointer
|
graph.checkpointer = checkpointer
|
||||||
graph.store = in_memory_store
|
graph.store = in_memory_store
|
||||||
logger.debug(f"[{thread_id}] Starting to stream graph events")
|
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
|
||||||
async for event in _stream_graph_events(
|
async for event in _stream_graph_events(
|
||||||
graph, workflow_input, workflow_config, thread_id
|
graph, workflow_input, workflow_config, thread_id
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
logger.debug(f"[{thread_id}] Graph event streaming completed")
|
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
|
||||||
|
|
||||||
if checkpoint_url.startswith("mongodb://"):
|
if checkpoint_url.startswith("mongodb://"):
|
||||||
logger.info(f"[{thread_id}] Starting async mongodb checkpointer")
|
logger.info(f"[{safe_thread_id}] Starting async mongodb checkpointer")
|
||||||
logger.debug(f"[{thread_id}] Setting up MongoDB connection")
|
logger.debug(f"[{safe_thread_id}] Setting up MongoDB connection")
|
||||||
async with AsyncMongoDBSaver.from_conn_string(
|
async with AsyncMongoDBSaver.from_conn_string(
|
||||||
checkpoint_url
|
checkpoint_url
|
||||||
) as checkpointer:
|
) as checkpointer:
|
||||||
logger.debug(f"[{thread_id}] Attaching MongoDB checkpointer to graph")
|
logger.debug(f"[{safe_thread_id}] Attaching MongoDB checkpointer to graph")
|
||||||
graph.checkpointer = checkpointer
|
graph.checkpointer = checkpointer
|
||||||
graph.store = in_memory_store
|
graph.store = in_memory_store
|
||||||
logger.debug(f"[{thread_id}] Starting to stream graph events")
|
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
|
||||||
async for event in _stream_graph_events(
|
async for event in _stream_graph_events(
|
||||||
graph, workflow_input, workflow_config, thread_id
|
graph, workflow_input, workflow_config, thread_id
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
logger.debug(f"[{thread_id}] Graph event streaming completed")
|
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"[{thread_id}] No checkpointer configured, using in-memory graph")
|
logger.debug(f"[{safe_thread_id}] No checkpointer configured, using in-memory graph")
|
||||||
# Use graph without MongoDB checkpointer
|
# Use graph without MongoDB checkpointer
|
||||||
logger.debug(f"[{thread_id}] Starting to stream graph events")
|
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
|
||||||
async for event in _stream_graph_events(
|
async for event in _stream_graph_events(
|
||||||
graph, workflow_input, workflow_config, thread_id
|
graph, workflow_input, workflow_config, thread_id
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
logger.debug(f"[{thread_id}] Graph event streaming completed")
|
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
|
||||||
|
|
||||||
|
|
||||||
def _make_event(event_type: str, data: dict[str, any]):
|
def _make_event(event_type: str, data: dict[str, any]):
|
||||||
|
|||||||
186
src/utils/log_sanitizer.py
Normal file
186
src/utils/log_sanitizer.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
"""
|
||||||
|
Log sanitization utilities to prevent log injection attacks.
|
||||||
|
|
||||||
|
This module provides functions to sanitize user-controlled input before
|
||||||
|
logging to prevent attackers from forging log entries through:
|
||||||
|
- Newline injection (\n)
|
||||||
|
- HTML injection (for HTML logs)
|
||||||
|
- Special character sequences that could be misinterpreted
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_log_input(value: Any, max_length: int = 500) -> str:
    """Return a representation of *value* that is safe to embed in a log line.

    Characters an attacker could use to forge log entries (newline, carriage
    return, tab, NUL, ESC) are replaced with their escaped spellings,
    backslashes are doubled so the escapes stay unambiguous, any remaining
    ASCII control characters are stripped, and the result is capped at
    *max_length* characters to prevent log flooding.

    Args:
        value: Arbitrary input to sanitize; converted with ``str()``.
        max_length: Upper bound on the length of the returned string.

    Returns:
        str: A sanitized, possibly truncated, string safe for logging.
    """
    if value is None:
        return "None"

    text = str(value)

    # Backslash must be handled before the escape sequences below are
    # introduced, otherwise those sequences would be double-escaped.
    escape_table = (
        ("\\", "\\\\"),
        ("\n", "\\n"),  # stops attackers from starting a fresh log entry
        ("\r", "\\r"),
        ("\t", "\\t"),
        ("\x00", "\\0"),
        ("\x1b", "\\x1b"),  # ESC, the prefix of ANSI colour sequences
    )
    for raw, escaped in escape_table:
        text = text.replace(raw, escaped)

    # Silently drop the remaining ASCII control characters (0-31) that were
    # not escaped above; they have no legitimate use in log output.
    text = re.sub(r"[\x00-\x08\x0b-\x0c\x0e-\x1f]", "", text)

    # Truncate over-long input so a single request cannot flood the logs.
    if len(text) > max_length:
        text = text[: max_length - 3] + "..."

    return text
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_thread_id(thread_id: Any) -> str:
    """Return *thread_id* sanitized for safe logging.

    Thread IDs are normally short alphanumeric identifiers (with hyphens
    and underscores), but the value comes from the client, so it is run
    through the standard escaping with a tight 100-character cap.

    Args:
        thread_id: The user-supplied thread ID.

    Returns:
        str: The sanitized thread ID.
    """
    return sanitize_log_input(thread_id, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_user_content(content: Any) -> str:
    """Return user message *content* sanitized for safe logging.

    Messages can be arbitrarily long free text, so on top of the standard
    escaping the output is truncated aggressively at 200 characters.

    Args:
        content: The user-provided message content.

    Returns:
        str: The sanitized (and possibly truncated) content.
    """
    return sanitize_log_input(content, max_length=200)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_agent_name(agent_name: Any) -> str:
    """Return *agent_name* sanitized for safe logging.

    Agent names should already be simple identifiers; escaping is applied
    anyway as defence in depth, with a 100-character cap.

    Args:
        agent_name: The agent identifier to sanitize.

    Returns:
        str: The sanitized agent name.
    """
    return sanitize_log_input(agent_name, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tool_name(tool_name: Any) -> str:
    """Return *tool_name* sanitized for safe logging.

    Tool names should already be simple identifiers; escaping is applied
    anyway as defence in depth, with a 100-character cap.

    Args:
        tool_name: The tool identifier to sanitize.

    Returns:
        str: The sanitized tool name.
    """
    return sanitize_log_input(tool_name, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_feedback(feedback: Any) -> str:
    """Return interrupt *feedback* sanitized for safe logging.

    Feedback is free-form text supplied by the user during an interrupt,
    so it is escaped and truncated more aggressively (150 characters).

    Args:
        feedback: The user feedback to sanitize.

    Returns:
        str: The sanitized, truncated feedback.
    """
    return sanitize_log_input(feedback, max_length=150)
|
||||||
|
|
||||||
|
|
||||||
|
def create_safe_log_message(template: str, **kwargs) -> str:
    """Build a log message from *template*, sanitizing every value first.

    Each keyword value is passed through :func:`sanitize_log_input` before
    being substituted into the ``{key}`` placeholders, so the resulting
    message cannot carry a log-injection payload.

    Args:
        template: Template string containing ``{key}`` placeholders.
        **kwargs: Values to substitute into the template.

    Returns:
        str: The formatted, injection-safe log message.
    """
    # Sanitize each value individually before formatting.
    sanitized = {}
    for key, value in kwargs.items():
        sanitized[key] = sanitize_log_input(value)

    return template.format(**sanitized)
|
||||||
@@ -11,12 +11,12 @@ Tests the complete flow of selective tool interrupts including:
|
|||||||
- Resume mechanism after interrupt
|
- Resume mechanism after interrupt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock, call
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
import pytest
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from src.agents.agents import create_agent
|
from src.agents.agents import create_agent
|
||||||
from src.agents.tool_interceptor import ToolInterceptor, wrap_tools_with_interceptor
|
from src.agents.tool_interceptor import ToolInterceptor, wrap_tools_with_interceptor
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
from langchain_core.tools import BaseTool, tool
|
||||||
from langchain_core.tools import tool, BaseTool
|
|
||||||
|
|
||||||
from src.agents.tool_interceptor import (
|
from src.agents.tool_interceptor import (
|
||||||
ToolInterceptor,
|
ToolInterceptor,
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, Mock
|
|
||||||
from src.crawler.jina_client import JinaClient
|
from src.crawler.jina_client import JinaClient
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from src.crawler.readability_extractor import ReadabilityExtractor
|
from src.crawler.readability_extractor import ReadabilityExtractor
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
from src.graph.nodes import validate_and_fix_plan
|
from src.graph.nodes import validate_and_fix_plan
|
||||||
|
|
||||||
|
|||||||
@@ -10,13 +10,14 @@ tool names from being concatenated when multiple tool calls happen in sequence.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import pytest
|
import os
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
# Import the functions to test
|
# Import the functions to test
|
||||||
# Note: We need to import from the app module
|
# Note: We need to import from the app module
|
||||||
import sys
|
import sys
|
||||||
import os
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
# Add src directory to path for imports
|
# Add src directory to path for imports
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../"))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../"))
|
||||||
|
|||||||
@@ -3,7 +3,11 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from src.utils.json_utils import repair_json_output, sanitize_tool_response, _extract_json_from_content
|
from src.utils.json_utils import (
|
||||||
|
_extract_json_from_content,
|
||||||
|
repair_json_output,
|
||||||
|
sanitize_tool_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestRepairJsonOutput:
|
class TestRepairJsonOutput:
|
||||||
|
|||||||
268
tests/unit/utils/test_log_sanitizer.py
Normal file
268
tests/unit/utils/test_log_sanitizer.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for log sanitization utilities.
|
||||||
|
|
||||||
|
This test file verifies that the log sanitizer properly prevents log injection attacks
|
||||||
|
by escaping dangerous characters in user-controlled input before logging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.utils.log_sanitizer import (
|
||||||
|
create_safe_log_message,
|
||||||
|
sanitize_agent_name,
|
||||||
|
sanitize_feedback,
|
||||||
|
sanitize_log_input,
|
||||||
|
sanitize_thread_id,
|
||||||
|
sanitize_tool_name,
|
||||||
|
sanitize_user_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizeLogInput:
    """Tests for the general-purpose sanitize_log_input function."""

    def test_sanitize_normal_text(self):
        """Benign text passes through unchanged."""
        assert sanitize_log_input("normal text") == "normal text"

    def test_sanitize_newline_injection(self):
        """A newline cannot be used to start a forged log entry."""
        attack = "abc\n[INFO] Forged log entry"
        sanitized = sanitize_log_input(attack)
        assert "\n" not in sanitized
        assert "\\n" in sanitized  # newline survives only in escaped form
        assert "[INFO]" in sanitized  # payload text is kept, just escaped

    def test_sanitize_carriage_return(self):
        """Carriage returns are escaped rather than emitted."""
        attack = "text\r[WARN] Forged entry"
        sanitized = sanitize_log_input(attack)
        assert "\r" not in sanitized
        assert "\\r" in sanitized

    def test_sanitize_tab_character(self):
        """Tabs are escaped rather than emitted."""
        attack = "text\t[ERROR] Forged"
        sanitized = sanitize_log_input(attack)
        assert "\t" not in sanitized
        assert "\\t" in sanitized

    def test_sanitize_null_character(self):
        """NUL bytes never reach the log output."""
        attack = "text\x00[CRITICAL]"
        sanitized = sanitize_log_input(attack)
        assert "\x00" not in sanitized

    def test_sanitize_backslash(self):
        """Backslashes are doubled so escapes stay unambiguous."""
        sanitized = sanitize_log_input("path\\to\\file")
        assert sanitized == "path\\\\to\\\\file"

    def test_sanitize_escape_character(self):
        """ESC (ANSI colour sequences) is escaped rather than emitted."""
        attack = "text\x1b[31mRED TEXT\x1b[0m"
        sanitized = sanitize_log_input(attack)
        assert "\x1b" not in sanitized
        assert "\\x1b" in sanitized

    def test_sanitize_max_length_truncation(self):
        """Over-long input is cut down to the requested maximum."""
        sanitized = sanitize_log_input("a" * 1000, max_length=100)
        assert len(sanitized) <= 100
        assert sanitized.endswith("...")

    def test_sanitize_none_value(self):
        """None is rendered as the literal string 'None'."""
        assert sanitize_log_input(None) == "None"

    def test_sanitize_numeric_value(self):
        """Non-string input is stringified before sanitization."""
        assert sanitize_log_input(12345) == "12345"

    def test_sanitize_complex_injection_attack(self):
        """A payload mixing several control characters is fully defused."""
        attack = 'thread-123\n[WARNING] Unauthorized\r[ERROR] System failure\t[CRITICAL] Shutdown'
        sanitized = sanitize_log_input(attack)
        # None of the raw control characters may survive ...
        assert "\n" not in sanitized
        assert "\r" not in sanitized
        assert "\t" not in sanitized
        # ... while the textual payload itself is preserved, escaped.
        assert "WARNING" in sanitized
        assert "ERROR" in sanitized
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizeThreadId:
    """Tests for thread-ID sanitization."""

    def test_thread_id_normal(self):
        """A well-formed thread ID is returned unchanged."""
        assert sanitize_thread_id("thread-123-abc") == "thread-123-abc"

    def test_thread_id_with_newline(self):
        """A newline smuggled into a thread ID is escaped."""
        attack = "thread-1\n[INFO] Forged"
        sanitized = sanitize_thread_id(attack)
        assert "\n" not in sanitized
        assert "\\n" in sanitized

    def test_thread_id_max_length(self):
        """Thread IDs are capped at 100 characters."""
        sanitized = sanitize_thread_id("x" * 200)
        assert len(sanitized) <= 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizeUserContent:
    """Tests for sanitization of user-provided message content."""

    def test_user_content_normal(self):
        """An ordinary message is returned unchanged."""
        message = "What is the weather today?"
        assert sanitize_user_content(message) == "What is the weather today?"

    def test_user_content_with_newline(self):
        """A newline in a user message is escaped."""
        attack = "My question\n[ADMIN] Delete user"
        sanitized = sanitize_user_content(attack)
        assert "\n" not in sanitized
        assert "\\n" in sanitized

    def test_user_content_max_length(self):
        """User content is truncated at the tighter 200-character cap."""
        sanitized = sanitize_user_content("x" * 500)
        assert len(sanitized) <= 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizeToolName:
    """Tests for tool-name sanitization."""

    def test_tool_name_normal(self):
        """A well-formed tool name is returned unchanged."""
        assert sanitize_tool_name("web_search") == "web_search"

    def test_tool_name_injection(self):
        """A newline smuggled into a tool name is escaped."""
        attack = "search\n[WARN] Forged"
        sanitized = sanitize_tool_name(attack)
        assert "\n" not in sanitized
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizeFeedback:
    """Tests for sanitization of user interrupt feedback."""

    def test_feedback_normal(self):
        """Ordinary feedback is returned unchanged."""
        assert sanitize_feedback("[accepted]") == "[accepted]"

    def test_feedback_injection(self):
        """A newline in feedback is escaped."""
        attack = "[approved]\n[CRITICAL] System down"
        sanitized = sanitize_feedback(attack)
        assert "\n" not in sanitized
        assert "\\n" in sanitized

    def test_feedback_max_length(self):
        """Feedback is truncated at the 150-character cap."""
        sanitized = sanitize_feedback("x" * 500)
        assert len(sanitized) <= 150
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSafeLogMessage:
    """Tests for the create_safe_log_message helper."""

    def test_safe_message_normal(self):
        """Clean values are substituted straight into the template."""
        message = create_safe_log_message(
            "[{thread_id}] Processing {tool_name}",
            thread_id="thread-1",
            tool_name="search",
        )
        assert message == "[thread-1] Processing search"

    def test_safe_message_with_injection(self):
        """Injected control characters in values come out escaped."""
        message = create_safe_log_message(
            "[{thread_id}] Tool: {tool_name}",
            thread_id="id\n[INFO] Forged",
            tool_name="search\r[ERROR]",
        )
        assert "\n" not in message
        assert "\r" not in message
        assert "\\n" in message
        assert "\\r" in message

    def test_safe_message_multiple_values(self):
        """Every keyword value is sanitized, not just the first."""
        message = create_safe_log_message(
            "[{id}] User: {user} Tool: {tool}",
            id="123",
            user="admin\t[WARN]",
            tool="delete\x1b[31m",
        )
        assert "\t" not in message
        assert "\x1b" not in message
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogInjectionAttackPrevention:
    """End-to-end style checks against known log-injection techniques."""

    def test_classic_log_injection_newline(self):
        """The textbook newline attack cannot create a second log entry."""
        payload = 'abc\n[WARNING] Unauthorized access detected'
        sanitized = sanitize_log_input(payload)
        # No raw newline may remain that would split the log line ...
        assert sanitized.count("\n") == 0
        # ... only the escaped form.
        assert "\\n" in sanitized

    def test_carriage_return_log_injection(self):
        """CRLF-based forgery is neutralised."""
        payload = "request_id\r\n[ERROR] CRITICAL FAILURE"
        sanitized = sanitize_log_input(payload)
        assert "\r" not in sanitized
        assert "\n" not in sanitized

    def test_html_injection_prevention(self):
        """Control characters are defused even around HTML payloads."""
        # HTML tags themselves are not dangerous inside plain log files;
        # stripping/escaping the control characters is what matters here.
        payload = "user\x1b[32m<script>alert('xss')</script>"
        sanitized = sanitize_log_input(payload)
        assert "\x1b" not in sanitized
        # The HTML text is kept as-is, only control characters are escaped.
        assert "<script>" in sanitized

    def test_multiple_injection_techniques(self):
        """A payload combining every technique is fully escaped."""
        payload = 'id_1\n\r\t[CRITICAL]\x1b[31m RED TEXT'
        sanitized = sanitize_log_input(payload)
        # No raw control characters may survive ...
        assert "\n" not in sanitized
        assert "\r" not in sanitized
        assert "\t" not in sanitized
        assert "\x1b" not in sanitized
        # ... each must appear in escaped form instead.
        assert "\\n" in sanitized
        assert "\\r" in sanitized
        assert "\\t" in sanitized
        assert "\\x1b" in sanitized
|
||||||
Reference in New Issue
Block a user