mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-04 06:32:13 +08:00
- Implement index-based grouping of tool call chunks in _process_tool_call_chunks()
- Add _validate_tool_call_chunks() for debug logging and validation
- Enhance _process_message_chunk() with tool call ID validation and boundary detection
- Add comprehensive unit tests (17 tests) for tool call chunk processing
- Fix issue where tool names were incorrectly concatenated (e.g., 'web_searchweb_search')
- Ensure chunks from different tool calls (different indices) remain properly separated
- Add detailed logging for debugging tool call streaming issues
* Update the code according to review suggestions
This commit is contained in:
@@ -132,19 +132,122 @@ async def chat_stream(request: ChatRequest):
|
||||
)
|
||||
|
||||
|
||||
def _validate_tool_call_chunks(tool_call_chunks):
|
||||
"""Validate and log tool call chunk structure for debugging."""
|
||||
if not tool_call_chunks:
|
||||
return
|
||||
|
||||
logger.debug(f"Validating tool_call_chunks: count={len(tool_call_chunks)}")
|
||||
|
||||
indices_seen = set()
|
||||
tool_ids_seen = set()
|
||||
|
||||
for i, chunk in enumerate(tool_call_chunks):
|
||||
index = chunk.get("index")
|
||||
tool_id = chunk.get("id")
|
||||
name = chunk.get("name", "")
|
||||
has_args = "args" in chunk
|
||||
|
||||
logger.debug(
|
||||
f"Chunk {i}: index={index}, id={tool_id}, name={name}, "
|
||||
f"has_args={has_args}, type={chunk.get('type')}"
|
||||
)
|
||||
|
||||
if index is not None:
|
||||
indices_seen.add(index)
|
||||
if tool_id:
|
||||
tool_ids_seen.add(tool_id)
|
||||
|
||||
if len(indices_seen) > 1:
|
||||
logger.debug(
|
||||
f"Multiple indices detected: {sorted(indices_seen)} - "
|
||||
f"This may indicate consecutive tool calls"
|
||||
)
|
||||
|
||||
|
||||
def _process_tool_call_chunks(tool_call_chunks):
|
||||
"""Process tool call chunks and sanitize arguments."""
|
||||
"""
|
||||
Process tool call chunks with proper index-based grouping.
|
||||
|
||||
This function handles the concatenation of tool call chunks that belong
|
||||
to the same tool call (same index) while properly segregating chunks
|
||||
from different tool calls (different indices).
|
||||
|
||||
The issue: In streaming, LangChain's ToolCallChunk concatenates string
|
||||
attributes (name, args) when chunks have the same index. We need to:
|
||||
1. Group chunks by index
|
||||
2. Detect index collisions with different tool names
|
||||
3. Accumulate arguments for the same index
|
||||
4. Return properly segregated tool calls
|
||||
"""
|
||||
if not tool_call_chunks:
|
||||
return []
|
||||
|
||||
_validate_tool_call_chunks(tool_call_chunks)
|
||||
|
||||
chunks = []
|
||||
chunk_by_index = {} # Group chunks by index to handle streaming accumulation
|
||||
|
||||
for chunk in tool_call_chunks:
|
||||
chunks.append(
|
||||
{
|
||||
index = chunk.get("index")
|
||||
chunk_id = chunk.get("id")
|
||||
|
||||
if index is not None:
|
||||
# Create or update entry for this index
|
||||
if index not in chunk_by_index:
|
||||
chunk_by_index[index] = {
|
||||
"name": "",
|
||||
"args": "",
|
||||
"id": chunk_id or "",
|
||||
"index": index,
|
||||
"type": chunk.get("type", ""),
|
||||
}
|
||||
|
||||
# Validate and accumulate tool name
|
||||
chunk_name = chunk.get("name", "")
|
||||
if chunk_name:
|
||||
stored_name = chunk_by_index[index]["name"]
|
||||
|
||||
# Check for index collision with different tool names
|
||||
if stored_name and stored_name != chunk_name:
|
||||
logger.warning(
|
||||
f"Tool name mismatch detected at index {index}: "
|
||||
f"'{stored_name}' != '{chunk_name}'. "
|
||||
f"This may indicate a streaming artifact or consecutive tool calls "
|
||||
f"with the same index assignment."
|
||||
)
|
||||
# Keep the first name to prevent concatenation
|
||||
else:
|
||||
chunk_by_index[index]["name"] = chunk_name
|
||||
|
||||
# Update ID if new one provided
|
||||
if chunk_id and not chunk_by_index[index]["id"]:
|
||||
chunk_by_index[index]["id"] = chunk_id
|
||||
|
||||
# Accumulate arguments
|
||||
if chunk.get("args"):
|
||||
chunk_by_index[index]["args"] += chunk.get("args", "")
|
||||
else:
|
||||
# Handle chunks without explicit index (edge case)
|
||||
logger.debug(f"Chunk without index encountered: {chunk}")
|
||||
chunks.append({
|
||||
"name": chunk.get("name", ""),
|
||||
"args": sanitize_args(chunk.get("args", "")),
|
||||
"id": chunk.get("id", ""),
|
||||
"index": chunk.get("index", 0),
|
||||
"index": 0,
|
||||
"type": chunk.get("type", ""),
|
||||
}
|
||||
})
|
||||
|
||||
# Convert indexed chunks to list, sorted by index for proper order
|
||||
for index in sorted(chunk_by_index.keys()):
|
||||
chunk_data = chunk_by_index[index]
|
||||
chunk_data["args"] = sanitize_args(chunk_data["args"])
|
||||
chunks.append(chunk_data)
|
||||
logger.debug(
|
||||
f"Processed tool call: index={index}, name={chunk_data['name']}, "
|
||||
f"id={chunk_data['id']}"
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@@ -236,22 +339,63 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
|
||||
|
||||
if isinstance(message_chunk, ToolMessage):
|
||||
# Tool Message - Return the result of the tool call
|
||||
event_stream_message["tool_call_id"] = message_chunk.tool_call_id
|
||||
tool_call_id = message_chunk.tool_call_id
|
||||
event_stream_message["tool_call_id"] = tool_call_id
|
||||
|
||||
# Validate tool_call_id for debugging
|
||||
if tool_call_id:
|
||||
logger.debug(f"Processing ToolMessage with tool_call_id: {tool_call_id}")
|
||||
else:
|
||||
logger.warning("ToolMessage received without tool_call_id")
|
||||
|
||||
yield _make_event("tool_call_result", event_stream_message)
|
||||
elif isinstance(message_chunk, AIMessageChunk):
|
||||
# AI Message - Raw message tokens
|
||||
if message_chunk.tool_calls:
|
||||
# AI Message - Tool Call
|
||||
# AI Message - Tool Call (complete tool calls)
|
||||
event_stream_message["tool_calls"] = message_chunk.tool_calls
|
||||
event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
|
||||
|
||||
# Process tool_call_chunks with proper index-based grouping
|
||||
processed_chunks = _process_tool_call_chunks(
|
||||
message_chunk.tool_call_chunks
|
||||
)
|
||||
if processed_chunks:
|
||||
event_stream_message["tool_call_chunks"] = processed_chunks
|
||||
logger.debug(
|
||||
f"Tool calls: {[tc.get('name') for tc in message_chunk.tool_calls]}, "
|
||||
f"Processed chunks: {len(processed_chunks)}"
|
||||
)
|
||||
|
||||
yield _make_event("tool_calls", event_stream_message)
|
||||
elif message_chunk.tool_call_chunks:
|
||||
# AI Message - Tool Call Chunks
|
||||
event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
|
||||
# AI Message - Tool Call Chunks (streaming)
|
||||
processed_chunks = _process_tool_call_chunks(
|
||||
message_chunk.tool_call_chunks
|
||||
)
|
||||
|
||||
# Emit separate events for chunks with different indices (tool call boundaries)
|
||||
if processed_chunks:
|
||||
prev_chunk = None
|
||||
for chunk in processed_chunks:
|
||||
current_index = chunk.get("index")
|
||||
|
||||
# Log index transitions to detect tool call boundaries
|
||||
if prev_chunk is not None and current_index != prev_chunk.get("index"):
|
||||
logger.debug(
|
||||
f"Tool call boundary detected: "
|
||||
f"index {prev_chunk.get('index')} ({prev_chunk.get('name')}) -> "
|
||||
f"{current_index} ({chunk.get('name')})"
|
||||
)
|
||||
|
||||
prev_chunk = chunk
|
||||
|
||||
# Include all processed chunks in the event
|
||||
event_stream_message["tool_call_chunks"] = processed_chunks
|
||||
logger.debug(
|
||||
f"Streamed {len(processed_chunks)} tool call chunk(s): "
|
||||
f"{[c.get('name') for c in processed_chunks]}"
|
||||
)
|
||||
|
||||
yield _make_event("tool_call_chunks", event_stream_message)
|
||||
else:
|
||||
# AI Message - Raw message tokens
|
||||
|
||||
Reference in New Issue
Block a user