mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-21 21:24:46 +08:00
security: add log injection attack prevention with input sanitization (#667)
* security: add log injection attack prevention with input sanitization - Created src/utils/log_sanitizer.py to sanitize user-controlled input before logging - Prevents log injection attacks using newlines, tabs, carriage returns, etc. - Escapes dangerous characters: \n, \r, \t, \0, \x1b - Provides specialized functions for different input types: - sanitize_log_input: general purpose sanitization - sanitize_thread_id: for user-provided thread IDs - sanitize_user_content: for user messages (more aggressive truncation) - sanitize_agent_name: for agent identifiers - sanitize_tool_name: for tool names - sanitize_feedback: for user interrupt feedback - create_safe_log_message: template-based safe message creation - Updated src/server/app.py to sanitize all user input in logging: - Thread IDs from request parameter - Message content from user - Agent names and node information - Tool names and feedback - Updated src/agents/tool_interceptor.py to sanitize: - Tool names during execution - User feedback during interrupt handling - Tool input data - Added 29 comprehensive unit tests covering: - Classic newline injection attacks - Carriage return injection - Tab and null character injection - HTML/ANSI escape sequence injection - Combined multi-character attacks - Truncation and length limits Fixes potential log forgery vulnerability where malicious users could inject fake log entries via unsanitized input containing control characters.
This commit is contained in:
@@ -55,6 +55,13 @@ from src.server.rag_request import (
|
||||
)
|
||||
from src.tools import VolcengineTTS
|
||||
from src.utils.json_utils import sanitize_args
|
||||
from src.utils.log_sanitizer import (
|
||||
sanitize_agent_name,
|
||||
sanitize_log_input,
|
||||
sanitize_thread_id,
|
||||
sanitize_tool_name,
|
||||
sanitize_user_content,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -333,9 +340,13 @@ def _process_initial_messages(message, thread_id):
|
||||
|
||||
async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
|
||||
"""Process a single message chunk and yield appropriate events."""
|
||||
|
||||
agent_name = _get_agent_name(agent, message_metadata)
|
||||
logger.debug(f"[{thread_id}] _process_message_chunk started for agent_name={agent_name}")
|
||||
logger.debug(f"[{thread_id}] Extracted agent_name: {agent_name}")
|
||||
safe_agent_name = sanitize_agent_name(agent_name)
|
||||
safe_thread_id = sanitize_thread_id(thread_id)
|
||||
safe_agent = sanitize_agent_name(agent)
|
||||
logger.debug(f"[{safe_thread_id}] _process_message_chunk started for agent={safe_agent_name}")
|
||||
logger.debug(f"[{safe_thread_id}] Extracted agent_name: {safe_agent_name}")
|
||||
|
||||
event_stream_message = _create_event_stream_message(
|
||||
message_chunk, message_metadata, thread_id, agent_name
|
||||
@@ -343,25 +354,29 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
|
||||
|
||||
if isinstance(message_chunk, ToolMessage):
|
||||
# Tool Message - Return the result of the tool call
|
||||
logger.debug(f"[{thread_id}] Processing ToolMessage")
|
||||
logger.debug(f"[{safe_thread_id}] Processing ToolMessage")
|
||||
tool_call_id = message_chunk.tool_call_id
|
||||
event_stream_message["tool_call_id"] = tool_call_id
|
||||
|
||||
# Validate tool_call_id for debugging
|
||||
if tool_call_id:
|
||||
logger.debug(f"[{thread_id}] ToolMessage with tool_call_id: {tool_call_id}")
|
||||
safe_tool_id = sanitize_log_input(tool_call_id, max_length=100)
|
||||
logger.debug(f"[{safe_thread_id}] ToolMessage with tool_call_id: {safe_tool_id}")
|
||||
else:
|
||||
logger.warning(f"[{thread_id}] ToolMessage received without tool_call_id")
|
||||
logger.warning(f"[{safe_thread_id}] ToolMessage received without tool_call_id")
|
||||
|
||||
logger.debug(f"[{thread_id}] Yielding tool_call_result event")
|
||||
logger.debug(f"[{safe_thread_id}] Yielding tool_call_result event")
|
||||
yield _make_event("tool_call_result", event_stream_message)
|
||||
elif isinstance(message_chunk, AIMessageChunk):
|
||||
# AI Message - Raw message tokens
|
||||
logger.debug(f"[{thread_id}] Processing AIMessageChunk, tool_calls={bool(message_chunk.tool_calls)}, tool_call_chunks={bool(message_chunk.tool_call_chunks)}")
|
||||
has_tool_calls = bool(message_chunk.tool_calls)
|
||||
has_chunks = bool(message_chunk.tool_call_chunks)
|
||||
logger.debug(f"[{safe_thread_id}] Processing AIMessageChunk, tool_calls={has_tool_calls}, tool_call_chunks={has_chunks}")
|
||||
|
||||
if message_chunk.tool_calls:
|
||||
# AI Message - Tool Call (complete tool calls)
|
||||
logger.debug(f"[{thread_id}] AIMessageChunk has complete tool_calls: {[tc.get('name', 'unknown') for tc in message_chunk.tool_calls]}")
|
||||
safe_tool_names = [sanitize_tool_name(tc.get('name', 'unknown')) for tc in message_chunk.tool_calls]
|
||||
logger.debug(f"[{safe_thread_id}] AIMessageChunk has complete tool_calls: {safe_tool_names}")
|
||||
event_stream_message["tool_calls"] = message_chunk.tool_calls
|
||||
|
||||
# Process tool_call_chunks with proper index-based grouping
|
||||
@@ -370,16 +385,18 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
|
||||
)
|
||||
if processed_chunks:
|
||||
event_stream_message["tool_call_chunks"] = processed_chunks
|
||||
safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
|
||||
logger.debug(
|
||||
f"[{thread_id}] Tool calls: {[tc.get('name') for tc in message_chunk.tool_calls]}, "
|
||||
f"[{safe_thread_id}] Tool calls: {safe_tool_names}, "
|
||||
f"Processed chunks: {len(processed_chunks)}"
|
||||
)
|
||||
|
||||
logger.debug(f"[{thread_id}] Yielding tool_calls event")
|
||||
logger.debug(f"[{safe_thread_id}] Yielding tool_calls event")
|
||||
yield _make_event("tool_calls", event_stream_message)
|
||||
elif message_chunk.tool_call_chunks:
|
||||
# AI Message - Tool Call Chunks (streaming)
|
||||
logger.debug(f"[{thread_id}] AIMessageChunk has streaming tool_call_chunks: {len(message_chunk.tool_call_chunks)} chunks")
|
||||
chunks_count = len(message_chunk.tool_call_chunks)
|
||||
logger.debug(f"[{safe_thread_id}] AIMessageChunk has streaming tool_call_chunks: {chunks_count} chunks")
|
||||
processed_chunks = _process_tool_call_chunks(
|
||||
message_chunk.tool_call_chunks
|
||||
)
|
||||
@@ -392,26 +409,30 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
|
||||
|
||||
# Log index transitions to detect tool call boundaries
|
||||
if prev_chunk is not None and current_index != prev_chunk.get("index"):
|
||||
prev_name = sanitize_tool_name(prev_chunk.get('name'))
|
||||
curr_name = sanitize_tool_name(chunk.get('name'))
|
||||
logger.debug(
|
||||
f"[{thread_id}] Tool call boundary detected: "
|
||||
f"index {prev_chunk.get('index')} ({prev_chunk.get('name')}) -> "
|
||||
f"{current_index} ({chunk.get('name')})"
|
||||
f"[{safe_thread_id}] Tool call boundary detected: "
|
||||
f"index {prev_chunk.get('index')} ({prev_name}) -> "
|
||||
f"{current_index} ({curr_name})"
|
||||
)
|
||||
|
||||
prev_chunk = chunk
|
||||
|
||||
# Include all processed chunks in the event
|
||||
event_stream_message["tool_call_chunks"] = processed_chunks
|
||||
safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
|
||||
logger.debug(
|
||||
f"[{thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): "
|
||||
f"{[c.get('name') for c in processed_chunks]}"
|
||||
f"[{safe_thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): "
|
||||
f"{safe_chunk_names}"
|
||||
)
|
||||
|
||||
logger.debug(f"[{thread_id}] Yielding tool_call_chunks event")
|
||||
logger.debug(f"[{safe_thread_id}] Yielding tool_call_chunks event")
|
||||
yield _make_event("tool_call_chunks", event_stream_message)
|
||||
else:
|
||||
# AI Message - Raw message tokens
|
||||
logger.debug(f"[{thread_id}] AIMessageChunk is raw message tokens, content_len={len(message_chunk.content) if isinstance(message_chunk.content, str) else 'unknown'}")
|
||||
content_len = len(message_chunk.content) if isinstance(message_chunk.content, str) else 0
|
||||
logger.debug(f"[{safe_thread_id}] AIMessageChunk is raw message tokens, content_len={content_len}")
|
||||
yield _make_event("message_chunk", event_stream_message)
|
||||
|
||||
|
||||
@@ -419,7 +440,8 @@ async def _stream_graph_events(
|
||||
graph_instance, workflow_input, workflow_config, thread_id
|
||||
):
|
||||
"""Stream events from the graph and process them."""
|
||||
logger.debug(f"[{thread_id}] Starting graph event stream with agent nodes")
|
||||
safe_thread_id = sanitize_thread_id(thread_id)
|
||||
logger.debug(f"[{safe_thread_id}] Starting graph event stream with agent nodes")
|
||||
try:
|
||||
event_count = 0
|
||||
async for agent, _, event_data in graph_instance.astream(
|
||||
@@ -429,28 +451,31 @@ async def _stream_graph_events(
|
||||
subgraphs=True,
|
||||
):
|
||||
event_count += 1
|
||||
logger.debug(f"[{thread_id}] Graph event #{event_count} received from agent: {agent}")
|
||||
safe_agent = sanitize_agent_name(agent)
|
||||
logger.debug(f"[{safe_thread_id}] Graph event #{event_count} received from agent: {safe_agent}")
|
||||
|
||||
if isinstance(event_data, dict):
|
||||
if "__interrupt__" in event_data:
|
||||
logger.debug(
|
||||
f"[{thread_id}] Processing interrupt event: "
|
||||
f"[{safe_thread_id}] Processing interrupt event: "
|
||||
f"ns={getattr(event_data['__interrupt__'][0], 'ns', 'unknown') if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 else 'unknown'}, "
|
||||
f"value_len={len(getattr(event_data['__interrupt__'][0], 'value', '')) if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 and hasattr(event_data['__interrupt__'][0], 'value') and hasattr(event_data['__interrupt__'][0].value, '__len__') else 'unknown'}"
|
||||
)
|
||||
yield _create_interrupt_event(thread_id, event_data)
|
||||
logger.debug(f"[{thread_id}] Dict event without interrupt, skipping")
|
||||
logger.debug(f"[{safe_thread_id}] Dict event without interrupt, skipping")
|
||||
continue
|
||||
|
||||
message_chunk, message_metadata = cast(
|
||||
tuple[BaseMessage, dict[str, Any]], event_data
|
||||
)
|
||||
|
||||
safe_node = sanitize_agent_name(message_metadata.get('langgraph_node', 'unknown'))
|
||||
safe_step = sanitize_log_input(message_metadata.get('langgraph_step', 'unknown'))
|
||||
logger.debug(
|
||||
f"[{thread_id}] Processing message chunk: "
|
||||
f"[{safe_thread_id}] Processing message chunk: "
|
||||
f"type={type(message_chunk).__name__}, "
|
||||
f"node={message_metadata.get('langgraph_node', 'unknown')}, "
|
||||
f"step={message_metadata.get('langgraph_step', 'unknown')}"
|
||||
f"node={safe_node}, "
|
||||
f"step={safe_step}"
|
||||
)
|
||||
|
||||
async for event in _process_message_chunk(
|
||||
@@ -458,9 +483,9 @@ async def _stream_graph_events(
|
||||
):
|
||||
yield event
|
||||
|
||||
logger.debug(f"[{thread_id}] Graph event stream completed. Total events: {event_count}")
|
||||
logger.debug(f"[{safe_thread_id}] Graph event stream completed. Total events: {event_count}")
|
||||
except Exception as e:
|
||||
logger.exception(f"[{thread_id}] Error during graph execution")
|
||||
logger.exception(f"[{safe_thread_id}] Error during graph execution")
|
||||
yield _make_event(
|
||||
"error",
|
||||
{
|
||||
@@ -488,34 +513,38 @@ async def _astream_workflow_generator(
|
||||
locale: str = "en-US",
|
||||
interrupt_before_tools: Optional[List[str]] = None,
|
||||
):
|
||||
safe_thread_id = sanitize_thread_id(thread_id)
|
||||
safe_feedback = sanitize_log_input(interrupt_feedback) if interrupt_feedback else ""
|
||||
logger.debug(
|
||||
f"[{thread_id}] _astream_workflow_generator starting: "
|
||||
f"[{safe_thread_id}] _astream_workflow_generator starting: "
|
||||
f"messages_count={len(messages)}, "
|
||||
f"auto_accepted_plan={auto_accepted_plan}, "
|
||||
f"interrupt_feedback={interrupt_feedback}, "
|
||||
f"interrupt_feedback={safe_feedback}, "
|
||||
f"interrupt_before_tools={interrupt_before_tools}"
|
||||
)
|
||||
|
||||
# Process initial messages
|
||||
logger.debug(f"[{thread_id}] Processing {len(messages)} initial messages")
|
||||
logger.debug(f"[{safe_thread_id}] Processing {len(messages)} initial messages")
|
||||
for message in messages:
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
logger.debug(f"[{thread_id}] Sending initial message to client: {message.get('content', '')[:100]}")
|
||||
safe_content = sanitize_user_content(message.get('content', ''))
|
||||
logger.debug(f"[{safe_thread_id}] Sending initial message to client: {safe_content}")
|
||||
_process_initial_messages(message, thread_id)
|
||||
|
||||
logger.debug(f"[{thread_id}] Reconstructing clarification history")
|
||||
logger.debug(f"[{safe_thread_id}] Reconstructing clarification history")
|
||||
clarification_history = reconstruct_clarification_history(messages)
|
||||
|
||||
logger.debug(f"[{thread_id}] Building clarified topic from history")
|
||||
logger.debug(f"[{safe_thread_id}] Building clarified topic from history")
|
||||
clarified_topic, clarification_history = build_clarified_topic_from_history(
|
||||
clarification_history
|
||||
)
|
||||
latest_message_content = messages[-1]["content"] if messages else ""
|
||||
clarified_research_topic = clarified_topic or latest_message_content
|
||||
logger.debug(f"[{thread_id}] Clarified research topic: {clarified_research_topic[:100]}")
|
||||
safe_topic = sanitize_user_content(clarified_research_topic)
|
||||
logger.debug(f"[{safe_thread_id}] Clarified research topic: {safe_topic}")
|
||||
|
||||
# Prepare workflow input
|
||||
logger.debug(f"[{thread_id}] Preparing workflow input")
|
||||
logger.debug(f"[{safe_thread_id}] Preparing workflow input")
|
||||
workflow_input = {
|
||||
"messages": messages,
|
||||
"plan_iterations": 0,
|
||||
@@ -533,7 +562,7 @@ async def _astream_workflow_generator(
|
||||
}
|
||||
|
||||
if not auto_accepted_plan and interrupt_feedback:
|
||||
logger.debug(f"[{thread_id}] Creating resume command with interrupt_feedback: {interrupt_feedback}")
|
||||
logger.debug(f"[{safe_thread_id}] Creating resume command with interrupt_feedback: {safe_feedback}")
|
||||
resume_msg = f"[{interrupt_feedback}]"
|
||||
if messages:
|
||||
resume_msg += f" {messages[-1]['content']}"
|
||||
@@ -541,7 +570,7 @@ async def _astream_workflow_generator(
|
||||
|
||||
# Prepare workflow config
|
||||
logger.debug(
|
||||
f"[{thread_id}] Preparing workflow config: "
|
||||
f"[{safe_thread_id}] Preparing workflow config: "
|
||||
f"max_plan_iterations={max_plan_iterations}, "
|
||||
f"max_step_num={max_step_num}, "
|
||||
f"report_style={report_style.value}, "
|
||||
@@ -564,7 +593,7 @@ async def _astream_workflow_generator(
|
||||
checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
|
||||
|
||||
logger.debug(
|
||||
f"[{thread_id}] Checkpoint configuration: "
|
||||
f"[{safe_thread_id}] Checkpoint configuration: "
|
||||
f"saver_enabled={checkpoint_saver}, "
|
||||
f"url_configured={bool(checkpoint_url)}"
|
||||
)
|
||||
@@ -577,48 +606,48 @@ async def _astream_workflow_generator(
|
||||
}
|
||||
if checkpoint_saver and checkpoint_url != "":
|
||||
if checkpoint_url.startswith("postgresql://"):
|
||||
logger.info(f"[{thread_id}] Starting async postgres checkpointer")
|
||||
logger.debug(f"[{thread_id}] Setting up PostgreSQL connection pool")
|
||||
logger.info(f"[{safe_thread_id}] Starting async postgres checkpointer")
|
||||
logger.debug(f"[{safe_thread_id}] Setting up PostgreSQL connection pool")
|
||||
async with AsyncConnectionPool(
|
||||
checkpoint_url, kwargs=connection_kwargs
|
||||
) as conn:
|
||||
logger.debug(f"[{thread_id}] Initializing AsyncPostgresSaver")
|
||||
logger.debug(f"[{safe_thread_id}] Initializing AsyncPostgresSaver")
|
||||
checkpointer = AsyncPostgresSaver(conn)
|
||||
await checkpointer.setup()
|
||||
logger.debug(f"[{thread_id}] Attaching checkpointer to graph")
|
||||
logger.debug(f"[{safe_thread_id}] Attaching checkpointer to graph")
|
||||
graph.checkpointer = checkpointer
|
||||
graph.store = in_memory_store
|
||||
logger.debug(f"[{thread_id}] Starting to stream graph events")
|
||||
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
|
||||
async for event in _stream_graph_events(
|
||||
graph, workflow_input, workflow_config, thread_id
|
||||
):
|
||||
yield event
|
||||
logger.debug(f"[{thread_id}] Graph event streaming completed")
|
||||
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
|
||||
|
||||
if checkpoint_url.startswith("mongodb://"):
|
||||
logger.info(f"[{thread_id}] Starting async mongodb checkpointer")
|
||||
logger.debug(f"[{thread_id}] Setting up MongoDB connection")
|
||||
logger.info(f"[{safe_thread_id}] Starting async mongodb checkpointer")
|
||||
logger.debug(f"[{safe_thread_id}] Setting up MongoDB connection")
|
||||
async with AsyncMongoDBSaver.from_conn_string(
|
||||
checkpoint_url
|
||||
) as checkpointer:
|
||||
logger.debug(f"[{thread_id}] Attaching MongoDB checkpointer to graph")
|
||||
logger.debug(f"[{safe_thread_id}] Attaching MongoDB checkpointer to graph")
|
||||
graph.checkpointer = checkpointer
|
||||
graph.store = in_memory_store
|
||||
logger.debug(f"[{thread_id}] Starting to stream graph events")
|
||||
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
|
||||
async for event in _stream_graph_events(
|
||||
graph, workflow_input, workflow_config, thread_id
|
||||
):
|
||||
yield event
|
||||
logger.debug(f"[{thread_id}] Graph event streaming completed")
|
||||
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
|
||||
else:
|
||||
logger.debug(f"[{thread_id}] No checkpointer configured, using in-memory graph")
|
||||
logger.debug(f"[{safe_thread_id}] No checkpointer configured, using in-memory graph")
|
||||
# Use graph without MongoDB checkpointer
|
||||
logger.debug(f"[{thread_id}] Starting to stream graph events")
|
||||
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
|
||||
async for event in _stream_graph_events(
|
||||
graph, workflow_input, workflow_config, thread_id
|
||||
):
|
||||
yield event
|
||||
logger.debug(f"[{thread_id}] Graph event streaming completed")
|
||||
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
|
||||
|
||||
|
||||
def _make_event(event_type: str, data: dict[str, any]):
|
||||
|
||||
Reference in New Issue
Block a user