fix: prevent tool name concatenation in consecutive tool calls to fix #523 (#654)

- Implement index-based grouping of tool call chunks in _process_tool_call_chunks()
- Add _validate_tool_call_chunks() for debug logging and validation
- Enhance _process_message_chunk() with tool call ID validation and boundary detection
- Add comprehensive unit tests (17 tests) for tool call chunk processing
- Fix issue where tool names were incorrectly concatenated (e.g., 'web_searchweb_search')
- Ensure chunks from different tool calls (different indices) remain properly separated
- Add detailed logging for debugging tool call streaming issues

* update the code per review suggestions
This commit is contained in:
Willem Jiang
2025-10-24 22:26:25 +08:00
committed by GitHub
parent 36bf5c9ccd
commit f2be4d6af1
2 changed files with 470 additions and 10 deletions

View File

@@ -132,19 +132,122 @@ async def chat_stream(request: ChatRequest):
)
def _validate_tool_call_chunks(tool_call_chunks):
    """Validate and log tool call chunk structure for debugging."""
    if not tool_call_chunks:
        return
    logger.debug(f"Validating tool_call_chunks: count={len(tool_call_chunks)}")
    seen_indices = set()
    seen_ids = set()
    for position, item in enumerate(tool_call_chunks):
        idx = item.get("index")
        call_id = item.get("id")
        call_name = item.get("name", "")
        # One debug line per chunk so streaming artifacts can be traced.
        logger.debug(
            f"Chunk {position}: index={idx}, id={call_id}, name={call_name}, "
            f"has_args={'args' in item}, type={item.get('type')}"
        )
        if idx is not None:
            seen_indices.add(idx)
        if call_id:
            seen_ids.add(call_id)
    # More than one distinct index usually means consecutive tool calls
    # arrived in the same message — worth surfacing for debugging.
    if len(seen_indices) > 1:
        logger.debug(
            f"Multiple indices detected: {sorted(seen_indices)} - "
            f"This may indicate consecutive tool calls"
        )
def _process_tool_call_chunks(tool_call_chunks):
    """Process tool call chunks with proper index-based grouping.

    In streaming, LangChain's ToolCallChunk concatenates string attributes
    (name, args) when chunks share the same index, which can produce
    artifacts such as 'web_searchweb_search'.  To prevent that we:

    1. Group chunks by index.
    2. Detect index collisions carrying different tool names and keep the
       first name instead of concatenating.
    3. Accumulate arguments for the same index.
    4. Return properly segregated tool calls, ordered by index.

    Args:
        tool_call_chunks: iterable of chunk dicts as produced by LangChain
            streaming (expected keys: "index", "id", "name", "args", "type").

    Returns:
        list of dicts with keys "name", "args" (sanitized), "id", "index",
        and "type" — one entry per distinct tool call.
    """
    if not tool_call_chunks:
        return []

    _validate_tool_call_chunks(tool_call_chunks)

    chunks = []
    chunk_by_index = {}  # Group chunks by index to handle streaming accumulation

    for chunk in tool_call_chunks:
        index = chunk.get("index")
        chunk_id = chunk.get("id")
        if index is not None:
            # Create the accumulator entry the first time this index appears.
            if index not in chunk_by_index:
                chunk_by_index[index] = {
                    "name": "",
                    "args": "",
                    "id": chunk_id or "",
                    "index": index,
                    "type": chunk.get("type", ""),
                }
            # Validate and accumulate the tool name.
            chunk_name = chunk.get("name", "")
            if chunk_name:
                stored_name = chunk_by_index[index]["name"]
                if stored_name and stored_name != chunk_name:
                    # Index collision with a different tool name: keep the
                    # first name to prevent concatenation (issue #523).
                    logger.warning(
                        f"Tool name mismatch detected at index {index}: "
                        f"'{stored_name}' != '{chunk_name}'. "
                        f"This may indicate a streaming artifact or consecutive tool calls "
                        f"with the same index assignment."
                    )
                else:
                    chunk_by_index[index]["name"] = chunk_name
            # Adopt an ID only if none has been recorded for this index yet.
            if chunk_id and not chunk_by_index[index]["id"]:
                chunk_by_index[index]["id"] = chunk_id
            # Arguments stream in fragments; concatenate them in order.
            if chunk.get("args"):
                chunk_by_index[index]["args"] += chunk.get("args", "")
        else:
            # Edge case: chunk without an explicit index — pass it through
            # individually under index 0.
            logger.debug(f"Chunk without index encountered: {chunk}")
            chunks.append({
                "name": chunk.get("name", ""),
                "args": sanitize_args(chunk.get("args", "")),
                "id": chunk.get("id", ""),
                "index": 0,
                "type": chunk.get("type", ""),
            })

    # Convert indexed chunks to a list, sorted by index for proper order.
    for index in sorted(chunk_by_index.keys()):
        chunk_data = chunk_by_index[index]
        chunk_data["args"] = sanitize_args(chunk_data["args"])
        chunks.append(chunk_data)
        logger.debug(
            f"Processed tool call: index={index}, name={chunk_data['name']}, "
            f"id={chunk_data['id']}"
        )
    return chunks
@@ -236,22 +339,63 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
if isinstance(message_chunk, ToolMessage):
# Tool Message - Return the result of the tool call
event_stream_message["tool_call_id"] = message_chunk.tool_call_id
tool_call_id = message_chunk.tool_call_id
event_stream_message["tool_call_id"] = tool_call_id
# Validate tool_call_id for debugging
if tool_call_id:
logger.debug(f"Processing ToolMessage with tool_call_id: {tool_call_id}")
else:
logger.warning("ToolMessage received without tool_call_id")
yield _make_event("tool_call_result", event_stream_message)
elif isinstance(message_chunk, AIMessageChunk):
# AI Message - Raw message tokens
if message_chunk.tool_calls:
# AI Message - Tool Call
# AI Message - Tool Call (complete tool calls)
event_stream_message["tool_calls"] = message_chunk.tool_calls
event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
# Process tool_call_chunks with proper index-based grouping
processed_chunks = _process_tool_call_chunks(
message_chunk.tool_call_chunks
)
if processed_chunks:
event_stream_message["tool_call_chunks"] = processed_chunks
logger.debug(
f"Tool calls: {[tc.get('name') for tc in message_chunk.tool_calls]}, "
f"Processed chunks: {len(processed_chunks)}"
)
yield _make_event("tool_calls", event_stream_message)
elif message_chunk.tool_call_chunks:
# AI Message - Tool Call Chunks
event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
# AI Message - Tool Call Chunks (streaming)
processed_chunks = _process_tool_call_chunks(
message_chunk.tool_call_chunks
)
# Emit separate events for chunks with different indices (tool call boundaries)
if processed_chunks:
prev_chunk = None
for chunk in processed_chunks:
current_index = chunk.get("index")
# Log index transitions to detect tool call boundaries
if prev_chunk is not None and current_index != prev_chunk.get("index"):
logger.debug(
f"Tool call boundary detected: "
f"index {prev_chunk.get('index')} ({prev_chunk.get('name')}) -> "
f"{current_index} ({chunk.get('name')})"
)
prev_chunk = chunk
# Include all processed chunks in the event
event_stream_message["tool_call_chunks"] = processed_chunks
logger.debug(
f"Streamed {len(processed_chunks)} tool call chunk(s): "
f"{[c.get('name') for c in processed_chunks]}"
)
yield _make_event("tool_call_chunks", event_stream_message)
else:
# AI Message - Raw message tokens