mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
- Implement index-based grouping of tool call chunks in _process_tool_call_chunks()
- Add _validate_tool_call_chunks() for debug logging and validation
- Enhance _process_message_chunk() with tool call ID validation and boundary detection
- Add comprehensive unit tests (17 tests) for tool call chunk processing
- Fix issue where tool names were incorrectly concatenated (e.g., 'web_searchweb_search')
- Ensure chunks from different tool calls (different indices) remain properly separated
- Add detailed logging for debugging tool call streaming issues
- Update the code with suggestions from review
This commit is contained in:
@@ -132,19 +132,122 @@ async def chat_stream(request: ChatRequest):
|
||||
)
|
||||
|
||||
|
||||
def _validate_tool_call_chunks(tool_call_chunks):
    """Validate and log tool call chunk structure for debugging.

    Purely diagnostic: emits debug-level logs describing each chunk and
    flags the presence of multiple indices (which usually indicates
    consecutive tool calls being streamed). Never raises and never
    mutates its input.

    Args:
        tool_call_chunks: Iterable of tool call chunk dicts (may be empty
            or None). Each dict may carry ``index``, ``id``, ``name``,
            ``args`` and ``type`` keys.
    """
    if not tool_call_chunks:
        return

    # Lazy %-style args so formatting is skipped when DEBUG is disabled
    # (this runs on every streamed chunk).
    logger.debug("Validating tool_call_chunks: count=%d", len(tool_call_chunks))

    indices_seen = set()
    tool_ids_seen = set()

    for i, chunk in enumerate(tool_call_chunks):
        index = chunk.get("index")
        tool_id = chunk.get("id")

        logger.debug(
            "Chunk %d: index=%s, id=%s, name=%s, has_args=%s, type=%s",
            i,
            index,
            tool_id,
            chunk.get("name", ""),
            "args" in chunk,
            chunk.get("type"),
        )

        if index is not None:
            indices_seen.add(index)
        if tool_id:
            tool_ids_seen.add(tool_id)

    # Summarize what was collected so both sets are actually reported.
    logger.debug(
        "Validation summary: %d unique index(es), %d unique tool id(s)",
        len(indices_seen),
        len(tool_ids_seen),
    )

    if len(indices_seen) > 1:
        logger.debug(
            "Multiple indices detected: %s - This may indicate consecutive tool calls",
            sorted(indices_seen),
        )
|
||||
|
||||
|
||||
def _process_tool_call_chunks(tool_call_chunks):
    """Process tool call chunks with proper index-based grouping.

    In streaming, LangChain emits partial ``ToolCallChunk`` dicts. Chunks
    that share an ``index`` belong to the same tool call and must be merged
    (string attributes such as ``name`` and ``args`` arrive in fragments),
    while chunks with different indices are separate tool calls and must
    stay segregated. Issue #523: naive concatenation across tool calls
    produced names like ``'web_searchweb_search'``.

    Behavior:
        1. Group chunks by index.
        2. Accumulate ``name`` fragments only when they belong to the same
           tool call (same id, or no conflicting id); warn on a genuine
           index collision between different tool calls and keep the first
           name to prevent concatenation.
        3. Accumulate ``args`` fragments per index and sanitize the result.
        4. Return tool calls ordered by index.

    Args:
        tool_call_chunks: Iterable of chunk dicts with optional keys
            ``name``, ``args``, ``id``, ``index`` and ``type``.

    Returns:
        list[dict]: One entry per tool call, ordered by index. Chunks that
        arrive without an index are passed through individually.
    """
    if not tool_call_chunks:
        return []

    _validate_tool_call_chunks(tool_call_chunks)

    chunks = []
    # Group streamed fragments by index so each tool call accumulates
    # independently of its neighbours.
    chunk_by_index = {}

    for chunk in tool_call_chunks:
        index = chunk.get("index")
        chunk_id = chunk.get("id")

        if index is None:
            # Edge case: a chunk without an index cannot be grouped;
            # emit it as-is with a default index.
            logger.debug("Chunk without index encountered: %s", chunk)
            chunks.append(
                {
                    "name": chunk.get("name", ""),
                    "args": sanitize_args(chunk.get("args", "")),
                    "id": chunk.get("id", ""),
                    "index": 0,
                    "type": chunk.get("type", ""),
                }
            )
            continue

        # Create the accumulator for this index on first sight.
        entry = chunk_by_index.setdefault(
            index,
            {
                "name": "",
                "args": "",
                "id": chunk_id or "",
                "index": index,
                "type": chunk.get("type", ""),
            },
        )

        chunk_name = chunk.get("name", "")
        if chunk_name:
            stored_name = entry["name"]
            if stored_name and stored_name != chunk_name:
                if chunk_id and entry["id"] and chunk_id != entry["id"]:
                    # A different tool call landed on an already-used index:
                    # keep the first name to prevent concatenation.
                    logger.warning(
                        f"Tool name mismatch detected at index {index}: "
                        f"'{stored_name}' != '{chunk_name}'. "
                        f"This may indicate a streaming artifact or consecutive tool calls "
                        f"with the same index assignment."
                    )
                else:
                    # Same tool call streamed its name in fragments
                    # (e.g. 'web_' then 'search'); accumulate them.
                    entry["name"] = stored_name + chunk_name
            else:
                entry["name"] = chunk_name

        # Adopt the first non-empty id seen for this index.
        if chunk_id and not entry["id"]:
            entry["id"] = chunk_id

        # Accumulate streamed argument fragments.
        if chunk.get("args"):
            entry["args"] += chunk.get("args", "")

    # Convert indexed chunks to a list, sorted by index for proper order.
    for index in sorted(chunk_by_index):
        chunk_data = chunk_by_index[index]
        chunk_data["args"] = sanitize_args(chunk_data["args"])
        chunks.append(chunk_data)
        logger.debug(
            "Processed tool call: index=%s, name=%s, id=%s",
            index,
            chunk_data["name"],
            chunk_data["id"],
        )

    return chunks
|
||||
|
||||
|
||||
@@ -236,22 +339,63 @@ async def _process_message_chunk(message_chunk, message_metadata, thread_id, age
|
||||
|
||||
if isinstance(message_chunk, ToolMessage):
|
||||
# Tool Message - Return the result of the tool call
|
||||
event_stream_message["tool_call_id"] = message_chunk.tool_call_id
|
||||
tool_call_id = message_chunk.tool_call_id
|
||||
event_stream_message["tool_call_id"] = tool_call_id
|
||||
|
||||
# Validate tool_call_id for debugging
|
||||
if tool_call_id:
|
||||
logger.debug(f"Processing ToolMessage with tool_call_id: {tool_call_id}")
|
||||
else:
|
||||
logger.warning("ToolMessage received without tool_call_id")
|
||||
|
||||
yield _make_event("tool_call_result", event_stream_message)
|
||||
elif isinstance(message_chunk, AIMessageChunk):
|
||||
# AI Message - Raw message tokens
|
||||
if message_chunk.tool_calls:
|
||||
# AI Message - Tool Call
|
||||
# AI Message - Tool Call (complete tool calls)
|
||||
event_stream_message["tool_calls"] = message_chunk.tool_calls
|
||||
event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
|
||||
|
||||
# Process tool_call_chunks with proper index-based grouping
|
||||
processed_chunks = _process_tool_call_chunks(
|
||||
message_chunk.tool_call_chunks
|
||||
)
|
||||
if processed_chunks:
|
||||
event_stream_message["tool_call_chunks"] = processed_chunks
|
||||
logger.debug(
|
||||
f"Tool calls: {[tc.get('name') for tc in message_chunk.tool_calls]}, "
|
||||
f"Processed chunks: {len(processed_chunks)}"
|
||||
)
|
||||
|
||||
yield _make_event("tool_calls", event_stream_message)
|
||||
elif message_chunk.tool_call_chunks:
|
||||
# AI Message - Tool Call Chunks
|
||||
event_stream_message["tool_call_chunks"] = _process_tool_call_chunks(
|
||||
# AI Message - Tool Call Chunks (streaming)
|
||||
processed_chunks = _process_tool_call_chunks(
|
||||
message_chunk.tool_call_chunks
|
||||
)
|
||||
|
||||
# Emit separate events for chunks with different indices (tool call boundaries)
|
||||
if processed_chunks:
|
||||
prev_chunk = None
|
||||
for chunk in processed_chunks:
|
||||
current_index = chunk.get("index")
|
||||
|
||||
# Log index transitions to detect tool call boundaries
|
||||
if prev_chunk is not None and current_index != prev_chunk.get("index"):
|
||||
logger.debug(
|
||||
f"Tool call boundary detected: "
|
||||
f"index {prev_chunk.get('index')} ({prev_chunk.get('name')}) -> "
|
||||
f"{current_index} ({chunk.get('name')})"
|
||||
)
|
||||
|
||||
prev_chunk = chunk
|
||||
|
||||
# Include all processed chunks in the event
|
||||
event_stream_message["tool_call_chunks"] = processed_chunks
|
||||
logger.debug(
|
||||
f"Streamed {len(processed_chunks)} tool call chunk(s): "
|
||||
f"{[c.get('name') for c in processed_chunks]}"
|
||||
)
|
||||
|
||||
yield _make_event("tool_call_chunks", event_stream_message)
|
||||
else:
|
||||
# AI Message - Raw message tokens
|
||||
|
||||
316
tests/unit/server/test_tool_call_chunks.py
Normal file
316
tests/unit/server/test_tool_call_chunks.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for tool call chunk processing.
|
||||
|
||||
Tests for the fix of issue #523: Tool name concatenation in consecutive tool calls.
|
||||
This ensures that tool call chunks are properly segregated by index to prevent
|
||||
tool names from being concatenated when multiple tool calls happen in sequence.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# Import the functions to test
|
||||
# Note: We need to import from the app module
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../"))
|
||||
|
||||
from src.server.app import _process_tool_call_chunks, _validate_tool_call_chunks
|
||||
|
||||
|
||||
class TestProcessToolCallChunks:
    """Test cases for _process_tool_call_chunks function."""

    def test_empty_tool_call_chunks(self):
        """An empty chunk list yields an empty result."""
        assert _process_tool_call_chunks([]) == []

    def test_single_tool_call_single_chunk(self):
        """A single complete chunk is passed through as one tool call."""
        chunks = [
            {"name": "web_search", "args": '{"query": "test"}', "id": "call_1", "index": 0}
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 1
        assert result[0]["name"] == "web_search"
        assert result[0]["id"] == "call_1"
        assert result[0]["index"] == 0
        assert '"query": "test"' in result[0]["args"]

    def test_consecutive_tool_calls_different_indices(self):
        """Consecutive tool calls with different indices are not concatenated."""
        chunks = [
            {"name": "web_search", "args": '{"query": "test"}', "id": "call_1", "index": 0},
            {"name": "web_search", "args": '{"query": "test2"}', "id": "call_2", "index": 1},
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 2
        assert result[0]["name"] == "web_search"
        assert result[0]["id"] == "call_1"
        assert result[0]["index"] == 0
        assert result[1]["name"] == "web_search"
        assert result[1]["id"] == "call_2"
        assert result[1]["index"] == 1
        # Verify names are NOT concatenated (the issue #523 symptom).
        assert result[0]["name"] != "web_searchweb_search"
        assert result[1]["name"] != "web_searchweb_search"

    def test_different_tools_different_indices(self):
        """Consecutive calls to different tools stay segregated."""
        chunks = [
            {"name": "web_search", "args": '{"query": "test"}', "id": "call_1", "index": 0},
            {"name": "crawl_tool", "args": '{"url": "http://example.com"}', "id": "call_2", "index": 1},
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 2
        assert result[0]["name"] == "web_search"
        assert result[1]["name"] == "crawl_tool"
        # Verify names are NOT concatenated (the issue bug scenario).
        assert "web_searchcrawl_tool" not in result[0]["name"]
        assert "web_searchcrawl_tool" not in result[1]["name"]

    def test_streaming_chunks_same_index(self):
        """Streaming chunks with the same index are properly accumulated."""
        chunks = [
            {"name": "web_", "args": '{"query"', "id": "call_1", "index": 0},
            {"name": "search", "args": ': "test"}', "id": "call_1", "index": 0},
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 1
        # Name should resolve to one tool, not cross-call concatenation.
        assert result[0]["name"] in ["web_", "search", "web_search"]
        assert result[0]["id"] == "call_1"
        # Args should be accumulated across fragments.
        assert "query" in result[0]["args"]
        assert "test" in result[0]["args"]

    def test_tool_call_index_collision_warning(self):
        """Index collision with different tool names generates a warning."""
        chunks = [
            {"name": "web_search", "args": '{}', "id": "call_1", "index": 0},
            {"name": "crawl_tool", "args": '{}', "id": "call_2", "index": 0},
        ]

        with patch('src.server.app.logger') as mock_logger:
            _process_tool_call_chunks(chunks)

            mock_logger.warning.assert_called()
            call_args = mock_logger.warning.call_args[0][0]
            assert "Tool name mismatch detected" in call_args
            assert "web_search" in call_args
            assert "crawl_tool" in call_args

    def test_chunks_without_explicit_index(self):
        """Chunks without an explicit index are passed through (edge case)."""
        chunks = [
            {"name": "web_search", "args": '{}', "id": "call_1"}  # No index
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 1
        assert result[0]["name"] == "web_search"

    def test_chunk_sorting_by_index(self):
        """Chunks are returned sorted by index regardless of arrival order."""
        chunks = [
            {"name": "tool_3", "args": '{}', "id": "call_3", "index": 2},
            {"name": "tool_1", "args": '{}', "id": "call_1", "index": 0},
            {"name": "tool_2", "args": '{}', "id": "call_2", "index": 1},
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 3
        assert result[0]["index"] == 0
        assert result[1]["index"] == 1
        assert result[2]["index"] == 2

    def test_args_accumulation(self):
        """Arguments are accumulated for chunks with the same index."""
        chunks = [
            {"name": "web_search", "args": '{"q', "id": "call_1", "index": 0},
            {"name": "web_search", "args": 'uery": "test"}', "id": "call_1", "index": 0},
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 1
        # sanitize_args may rewrite the payload, so only check accumulation.
        assert len(result[0]["args"]) > 0

    def test_preserve_tool_id(self):
        """Tool IDs are preserved per tool call."""
        chunks = [
            {"name": "web_search", "args": '{}', "id": "call_abc123", "index": 0},
            {"name": "web_search", "args": '{}', "id": "call_xyz789", "index": 1},
        ]

        result = _process_tool_call_chunks(chunks)

        assert result[0]["id"] == "call_abc123"
        assert result[1]["id"] == "call_xyz789"

    def test_multiple_indices_detected(self):
        """Multiple indices are detected, logged, and kept separate."""
        chunks = [
            {"name": "tool_a", "args": '{}', "id": "call_1", "index": 0},
            {"name": "tool_b", "args": '{}', "id": "call_2", "index": 1},
            {"name": "tool_c", "args": '{}', "id": "call_3", "index": 2},
        ]

        with patch('src.server.app.logger') as mock_logger:
            result = _process_tool_call_chunks(chunks)

            debug_calls = [call[0][0] for call in mock_logger.debug.call_args_list]
            multiple_indices_mentioned = any(
                "Multiple indices" in call for call in debug_calls
            )
            # Both conditions must hold: the original `or` form was
            # tautological because len(result) == 3 always passed.
            assert multiple_indices_mentioned
            assert len(result) == 3
|
||||
|
||||
|
||||
class TestValidateToolCallChunks:
    """Test cases for _validate_tool_call_chunks function."""

    def test_validate_empty_chunks(self):
        """Validating an empty list is a no-op and must not raise."""
        _validate_tool_call_chunks([])

    def test_validate_logs_chunk_info(self):
        """Validation emits debug logs describing each chunk."""
        sample = [{"name": "web_search", "args": '{}', "id": "call_1", "index": 0}]

        with patch('src.server.app.logger') as mock_logger:
            _validate_tool_call_chunks(sample)
            assert mock_logger.debug.called

    def test_validate_detects_multiple_indices(self):
        """Validation notices when chunks span more than one index."""
        sample = [
            {"name": "tool_1", "args": '{}', "id": "call_1", "index": 0},
            {"name": "tool_2", "args": '{}', "id": "call_2", "index": 1},
        ]

        with patch('src.server.app.logger') as mock_logger:
            _validate_tool_call_chunks(sample)

            messages = (call[0][0] for call in mock_logger.debug.call_args_list)
            assert any("Multiple indices" in message for message in messages)
|
||||
|
||||
|
||||
class TestRealWorldScenarios:
    """Test cases for real-world scenarios from issue #523."""

    def test_issue_523_scenario_consecutive_web_search(self):
        """
        Replicate issue #523: Consecutive web_search calls.
        Previously would result in "web_searchweb_search" error.
        """
        # Simulate streaming chunks from two consecutive web_search calls.
        chunks = [
            # First web_search call (index 0)
            {"name": "web_", "args": '{"query', "id": "call_1", "index": 0},
            {"name": "search", "args": '": "first query"}', "id": "call_1", "index": 0},
            # Second web_search call (index 1)
            {"name": "web_", "args": '{"query', "id": "call_2", "index": 1},
            {"name": "search", "args": '": "second query"}', "id": "call_2", "index": 1},
        ]

        result = _process_tool_call_chunks(chunks)

        # Two distinct indices must yield exactly two grouped tool calls
        # (the original `>= 1` assertion was too weak to catch regressions).
        assert len(result) == 2

        tool_names = [chunk.get("name") for chunk in result]

        # Verify the "web_searchweb_search" concatenation does not occur.
        assert "web_searchweb_search" not in tool_names

        # Both calls should carry web_search (or a fragment of it).
        concatenated_names = "".join(tool_names)
        assert "web_search" in concatenated_names or "web_" in concatenated_names

    def test_mixed_tools_consecutive_calls(self):
        """Test realistic scenario with mixed tools in sequence."""
        chunks = [
            # web_search call
            {"name": "web_search", "args": '{"query": "python"}', "id": "1", "index": 0},
            # crawl_tool call
            {"name": "crawl_tool", "args": '{"url": "http://example.com"}', "id": "2", "index": 1},
            # Another web_search
            {"name": "web_search", "args": '{"query": "rust"}', "id": "3", "index": 2},
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 3
        tool_names = [chunk.get("name") for chunk in result]

        # No cross-tool concatenation should occur.
        assert "web_searchcrawl_tool" not in tool_names
        assert "crawl_toolweb_search" not in tool_names

    def test_long_sequence_tool_calls(self):
        """Test a long sequence of alternating tool calls."""
        chunks = []
        for i in range(10):
            tool_name = "web_search" if i % 2 == 0 else "crawl_tool"
            chunks.append({
                "name": tool_name,
                "args": '{"query": "test"}' if tool_name == "web_search" else '{"url": "http://example.com"}',
                "id": f"call_{i}",
                "index": i,
            })

        result = _process_tool_call_chunks(chunks)

        # All 10 tool calls should be processed.
        assert len(result) == 10

        # Verify each tool call kept its own name, id and index.
        for i, chunk in enumerate(result):
            expected_name = "web_search" if i % 2 == 0 else "crawl_tool"
            actual_name = chunk.get("name", "")

            assert actual_name == expected_name, (
                f"Tool call {i} has name '{actual_name}', expected '{expected_name}'. "
                f"This indicates concatenation with adjacent tool call."
            )
            assert chunk.get("id") == f"call_{i}"
            assert chunk.get("index") == i
|
||||
|
||||
|
||||
# Allow running this test module directly without invoking pytest by hand.
if __name__ == "__main__":
    cli_args = [__file__, "-v", "--tb=short"]
    pytest.main(cli_args)
|
||||
Reference in New Issue
Block a user