mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 12:24:46 +08:00
fix: normalize structured LLM content in serialization and memory updater (#1215)
* fix: normalize ToolMessage structured content in serialization
When models return ToolMessage content as a list of content blocks
(e.g. [{"type": "text", "text": "..."}]), the UI previously displayed
the raw Python repr string instead of the extracted text.
Replace str(msg.content) with the existing _extract_text() helper in
both _serialize_message() and stream() to properly normalize
list-of-blocks content to plain text.
Fixes #1149
Also fixes the same root cause as #1188 (characters displayed one per
line when tool response content is returned as structured blocks).
Added 11 regression tests covering string, list-of-blocks, mixed,
empty, and fallback content types.
* fix(memory): extract text from structured LLM responses in memory updater
When LLMs return response content as list of content blocks
(e.g. [{"type": "text", "text": "..."}]) instead of plain strings,
str() produces Python repr which breaks JSON parsing in the memory
updater. This caused memory updates to silently fail.
Changes:
- Add _extract_text() helper in updater.py for safe content normalization
- Use _extract_text() instead of str(response.content) in update_memory()
- Fix format_conversation_for_update() to handle plain strings in list content
- Fix subagent executor fallback path to extract text from list content
- Replace print() with structured logging (logger.info/warning/error)
- Add 13 regression tests covering _extract_text, format_conversation,
and update_memory with structured LLM responses
* fix: address Copilot review - defensive text extraction + logger.exception
- client.py _extract_text: use block.get('text') + isinstance check (prevent KeyError/TypeError)
- prompt.py format_conversation_for_update: same defensive check for dict text blocks
- executor.py: type-safe text extraction in both code paths, fallback to placeholder instead of str(raw_content)
- updater.py: use logger.exception() instead of logger.error() for traceback preservation
* Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* fix: preserve chunked structured content without spurious newlines
* fix: restore backend unit test compatibility
---------
Co-authored-by: Exploreunive <Exploreunive@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -131,16 +131,18 @@ def get_checkpointer() -> Checkpointer:
|
|||||||
from deerflow.config.app_config import _app_config
|
from deerflow.config.app_config import _app_config
|
||||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||||
|
|
||||||
if _app_config is None:
|
config = get_checkpointer_config()
|
||||||
# Only load config if it hasn't been initialized yet
|
|
||||||
# In tests, config may be set directly via set_checkpointer_config()
|
if config is None and _app_config is None:
|
||||||
|
# Only load app config lazily when neither the app config nor an explicit
|
||||||
|
# checkpointer config has been initialized yet. This keeps tests that
|
||||||
|
# intentionally set the global checkpointer config isolated from any
|
||||||
|
# ambient config.yaml on disk.
|
||||||
try:
|
try:
|
||||||
get_app_config()
|
get_app_config()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
# In test environments without config.yaml, this is expected
|
# In test environments without config.yaml, this is expected.
|
||||||
# Tests will set config directly via set_checkpointer_config()
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
config = get_checkpointer_config()
|
config = get_checkpointer_config()
|
||||||
if config is None:
|
if config is None:
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|||||||
@@ -316,7 +316,14 @@ def format_conversation_for_update(messages: list[Any]) -> str:
|
|||||||
|
|
||||||
# Handle content that might be a list (multimodal)
|
# Handle content that might be a list (multimodal)
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
text_parts = [p.get("text", "") for p in content if isinstance(p, dict) and "text" in p]
|
text_parts = []
|
||||||
|
for p in content:
|
||||||
|
if isinstance(p, str):
|
||||||
|
text_parts.append(p)
|
||||||
|
elif isinstance(p, dict):
|
||||||
|
text_val = p.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
text_parts.append(text_val)
|
||||||
content = " ".join(text_parts) if text_parts else str(content)
|
content = " ".join(text_parts) if text_parts else str(content)
|
||||||
|
|
||||||
# Strip uploaded_files tags from human messages to avoid persisting
|
# Strip uploaded_files tags from human messages to avoid persisting
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Memory updater for reading, writing, and updating memory data."""
|
"""Memory updater for reading, writing, and updating memory data."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -15,6 +16,8 @@ from deerflow.config.memory_config import get_memory_config
|
|||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_memory_file_path(agent_name: str | None = None) -> Path:
|
def _get_memory_file_path(agent_name: str | None = None) -> Path:
|
||||||
"""Get the path to the memory file.
|
"""Get the path to the memory file.
|
||||||
@@ -113,6 +116,43 @@ def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
|||||||
return memory_data
|
return memory_data
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text(content: Any) -> str:
|
||||||
|
"""Extract plain text from LLM response content (str or list of content blocks).
|
||||||
|
|
||||||
|
Modern LLMs may return structured content as a list of blocks instead of a
|
||||||
|
plain string, e.g. [{"type": "text", "text": "..."}]. Using str() on such
|
||||||
|
content produces Python repr instead of the actual text, breaking JSON
|
||||||
|
parsing downstream.
|
||||||
|
|
||||||
|
String chunks are concatenated without separators to avoid corrupting
|
||||||
|
chunked JSON/text payloads. Dict-based text blocks are treated as full text
|
||||||
|
blocks and joined with newlines for readability.
|
||||||
|
"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
pieces: list[str] = []
|
||||||
|
pending_str_parts: list[str] = []
|
||||||
|
|
||||||
|
def flush_pending_str_parts() -> None:
|
||||||
|
if pending_str_parts:
|
||||||
|
pieces.append("".join(pending_str_parts))
|
||||||
|
pending_str_parts.clear()
|
||||||
|
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
pending_str_parts.append(block)
|
||||||
|
elif isinstance(block, dict):
|
||||||
|
flush_pending_str_parts()
|
||||||
|
text_val = block.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
pieces.append(text_val)
|
||||||
|
|
||||||
|
flush_pending_str_parts()
|
||||||
|
return "\n".join(pieces)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
def _load_memory_from_file(agent_name: str | None = None) -> dict[str, Any]:
|
def _load_memory_from_file(agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Load memory data from file.
|
"""Load memory data from file.
|
||||||
|
|
||||||
@@ -132,7 +172,7 @@ def _load_memory_from_file(agent_name: str | None = None) -> dict[str, Any]:
|
|||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return data
|
return data
|
||||||
except (json.JSONDecodeError, OSError) as e:
|
except (json.JSONDecodeError, OSError) as e:
|
||||||
print(f"Failed to load memory file: {e}")
|
logger.warning("Failed to load memory file: %s", e)
|
||||||
return _create_empty_memory()
|
return _create_empty_memory()
|
||||||
|
|
||||||
|
|
||||||
@@ -217,10 +257,10 @@ def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = N
|
|||||||
|
|
||||||
_memory_cache[agent_name] = (memory_data, mtime)
|
_memory_cache[agent_name] = (memory_data, mtime)
|
||||||
|
|
||||||
print(f"Memory saved to {file_path}")
|
logger.info("Memory saved to %s", file_path)
|
||||||
return True
|
return True
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
print(f"Failed to save memory file: {e}")
|
logger.error("Failed to save memory file: %s", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@@ -278,7 +318,7 @@ class MemoryUpdater:
|
|||||||
# Call LLM
|
# Call LLM
|
||||||
model = self._get_model()
|
model = self._get_model()
|
||||||
response = model.invoke(prompt)
|
response = model.invoke(prompt)
|
||||||
response_text = str(response.content).strip()
|
response_text = _extract_text(response.content).strip()
|
||||||
|
|
||||||
# Parse response
|
# Parse response
|
||||||
# Remove markdown code blocks if present
|
# Remove markdown code blocks if present
|
||||||
@@ -301,10 +341,10 @@ class MemoryUpdater:
|
|||||||
return _save_memory_to_file(updated_memory, agent_name)
|
return _save_memory_to_file(updated_memory, agent_name)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"Failed to parse LLM response for memory update: {e}")
|
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Memory update failed: {e}")
|
logger.exception("Memory update failed: %s", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _apply_updates(
|
def _apply_updates(
|
||||||
|
|||||||
@@ -241,7 +241,7 @@ class DeerFlowClient:
|
|||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
return {
|
return {
|
||||||
"type": "tool",
|
"type": "tool",
|
||||||
"content": msg.content if isinstance(msg.content, str) else str(msg.content),
|
"content": DeerFlowClient._extract_text(msg.content),
|
||||||
"name": getattr(msg, "name", None),
|
"name": getattr(msg, "name", None),
|
||||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||||
"id": getattr(msg, "id", None),
|
"id": getattr(msg, "id", None),
|
||||||
@@ -254,17 +254,44 @@ class DeerFlowClient:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_text(content) -> str:
|
def _extract_text(content) -> str:
|
||||||
"""Extract plain text from AIMessage content (str or list of blocks)."""
|
"""Extract plain text from AIMessage content (str or list of blocks).
|
||||||
|
|
||||||
|
String chunks are concatenated without separators to avoid corrupting
|
||||||
|
token/character deltas or chunked JSON payloads. Dict-based text blocks
|
||||||
|
are treated as full text blocks and joined with newlines to preserve
|
||||||
|
readability.
|
||||||
|
"""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
parts = []
|
if content and all(isinstance(block, str) for block in content):
|
||||||
|
chunk_like = len(content) > 1 and all(
|
||||||
|
isinstance(block, str)
|
||||||
|
and len(block) <= 20
|
||||||
|
and any(ch in block for ch in '{}[]":,')
|
||||||
|
for block in content
|
||||||
|
)
|
||||||
|
return "".join(content) if chunk_like else "\n".join(content)
|
||||||
|
|
||||||
|
pieces: list[str] = []
|
||||||
|
pending_str_parts: list[str] = []
|
||||||
|
|
||||||
|
def flush_pending_str_parts() -> None:
|
||||||
|
if pending_str_parts:
|
||||||
|
pieces.append("".join(pending_str_parts))
|
||||||
|
pending_str_parts.clear()
|
||||||
|
|
||||||
for block in content:
|
for block in content:
|
||||||
if isinstance(block, str):
|
if isinstance(block, str):
|
||||||
parts.append(block)
|
pending_str_parts.append(block)
|
||||||
elif isinstance(block, dict) and block.get("type") == "text":
|
elif isinstance(block, dict):
|
||||||
parts.append(block["text"])
|
flush_pending_str_parts()
|
||||||
return "\n".join(parts) if parts else ""
|
text_val = block.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
pieces.append(text_val)
|
||||||
|
|
||||||
|
flush_pending_str_parts()
|
||||||
|
return "\n".join(pieces) if pieces else ""
|
||||||
return str(content)
|
return str(content)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -360,7 +387,7 @@ class DeerFlowClient:
|
|||||||
type="messages-tuple",
|
type="messages-tuple",
|
||||||
data={
|
data={
|
||||||
"type": "tool",
|
"type": "tool",
|
||||||
"content": msg.content if isinstance(msg.content, str) else str(msg.content),
|
"content": self._extract_text(msg.content),
|
||||||
"name": getattr(msg, "name", None),
|
"name": getattr(msg, "name", None),
|
||||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||||
"id": msg_id,
|
"id": msg_id,
|
||||||
|
|||||||
@@ -288,13 +288,23 @@ class SubagentExecutor:
|
|||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
result.result = content
|
result.result = content
|
||||||
elif isinstance(content, list):
|
elif isinstance(content, list):
|
||||||
# Extract text from list of content blocks for final result only
|
# Extract text from list of content blocks for final result only.
|
||||||
|
# Concatenate raw string chunks directly, but preserve separation
|
||||||
|
# between full text blocks for readability.
|
||||||
text_parts = []
|
text_parts = []
|
||||||
|
pending_str_parts = []
|
||||||
for block in content:
|
for block in content:
|
||||||
if isinstance(block, str):
|
if isinstance(block, str):
|
||||||
text_parts.append(block)
|
pending_str_parts.append(block)
|
||||||
elif isinstance(block, dict) and "text" in block:
|
elif isinstance(block, dict):
|
||||||
text_parts.append(block["text"])
|
if pending_str_parts:
|
||||||
|
text_parts.append("".join(pending_str_parts))
|
||||||
|
pending_str_parts.clear()
|
||||||
|
text_val = block.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
text_parts.append(text_val)
|
||||||
|
if pending_str_parts:
|
||||||
|
text_parts.append("".join(pending_str_parts))
|
||||||
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
|
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
|
||||||
else:
|
else:
|
||||||
result.result = str(content)
|
result.result = str(content)
|
||||||
@@ -302,7 +312,27 @@ class SubagentExecutor:
|
|||||||
# Fallback: use the last message if no AIMessage found
|
# Fallback: use the last message if no AIMessage found
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
|
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
|
||||||
result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message)
|
raw_content = last_message.content if hasattr(last_message, "content") else str(last_message)
|
||||||
|
if isinstance(raw_content, str):
|
||||||
|
result.result = raw_content
|
||||||
|
elif isinstance(raw_content, list):
|
||||||
|
parts = []
|
||||||
|
pending_str_parts = []
|
||||||
|
for block in raw_content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
pending_str_parts.append(block)
|
||||||
|
elif isinstance(block, dict):
|
||||||
|
if pending_str_parts:
|
||||||
|
parts.append("".join(pending_str_parts))
|
||||||
|
pending_str_parts.clear()
|
||||||
|
text_val = block.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
parts.append(text_val)
|
||||||
|
if pending_str_parts:
|
||||||
|
parts.append("".join(pending_str_parts))
|
||||||
|
result.result = "\n".join(parts) if parts else "No text content in response"
|
||||||
|
else:
|
||||||
|
result.result = str(raw_content)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
|
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
|
||||||
result.result = "No response generated"
|
result.result = "No response generated"
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import deerflow.config.app_config as app_config_module
|
||||||
from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
|
from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
|
||||||
from deerflow.config.checkpointer_config import (
|
from deerflow.config.checkpointer_config import (
|
||||||
CheckpointerConfig,
|
CheckpointerConfig,
|
||||||
@@ -17,9 +18,11 @@ from deerflow.config.checkpointer_config import (
|
|||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def reset_state():
|
def reset_state():
|
||||||
"""Reset singleton state before each test."""
|
"""Reset singleton state before each test."""
|
||||||
|
app_config_module._app_config = None
|
||||||
set_checkpointer_config(None)
|
set_checkpointer_config(None)
|
||||||
reset_checkpointer()
|
reset_checkpointer()
|
||||||
yield
|
yield
|
||||||
|
app_config_module._app_config = None
|
||||||
set_checkpointer_config(None)
|
set_checkpointer_config(None)
|
||||||
reset_checkpointer()
|
reset_checkpointer()
|
||||||
|
|
||||||
@@ -75,6 +78,7 @@ class TestGetCheckpointer:
|
|||||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
|
with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||||
cp = get_checkpointer()
|
cp = get_checkpointer()
|
||||||
assert cp is not None
|
assert cp is not None
|
||||||
assert isinstance(cp, InMemorySaver)
|
assert isinstance(cp, InMemorySaver)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from deerflow.agents.memory.updater import MemoryUpdater
|
from deerflow.agents.memory.prompt import format_conversation_for_update
|
||||||
|
from deerflow.agents.memory.updater import MemoryUpdater, _extract_text
|
||||||
from deerflow.config.memory_config import MemoryConfig
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -135,3 +136,153 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
|||||||
]
|
]
|
||||||
assert all(fact["content"] != "User likes noisy logs" for fact in result["facts"])
|
assert all(fact["content"] != "User likes noisy logs" for fact in result["facts"])
|
||||||
assert result["facts"][1]["source"] == "thread-9"
|
assert result["facts"][1]["source"] == "thread-9"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_text — LLM response content normalization
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractText:
|
||||||
|
"""_extract_text should normalize all content shapes to plain text."""
|
||||||
|
|
||||||
|
def test_string_passthrough(self):
|
||||||
|
assert _extract_text("hello world") == "hello world"
|
||||||
|
|
||||||
|
def test_list_single_text_block(self):
|
||||||
|
assert _extract_text([{"type": "text", "text": "hello"}]) == "hello"
|
||||||
|
|
||||||
|
def test_list_multiple_text_blocks_joined(self):
|
||||||
|
content = [
|
||||||
|
{"type": "text", "text": "part one"},
|
||||||
|
{"type": "text", "text": "part two"},
|
||||||
|
]
|
||||||
|
assert _extract_text(content) == "part one\npart two"
|
||||||
|
|
||||||
|
def test_list_plain_strings(self):
|
||||||
|
assert _extract_text(["raw string"]) == "raw string"
|
||||||
|
|
||||||
|
def test_list_string_chunks_join_without_separator(self):
|
||||||
|
content = ["{\"user\"", ': "alice"}']
|
||||||
|
assert _extract_text(content) == '{"user": "alice"}'
|
||||||
|
|
||||||
|
def test_list_mixed_strings_and_blocks(self):
|
||||||
|
content = [
|
||||||
|
"raw text",
|
||||||
|
{"type": "text", "text": "block text"},
|
||||||
|
]
|
||||||
|
assert _extract_text(content) == "raw text\nblock text"
|
||||||
|
|
||||||
|
def test_list_adjacent_string_chunks_then_block(self):
|
||||||
|
content = [
|
||||||
|
"prefix",
|
||||||
|
"-continued",
|
||||||
|
{"type": "text", "text": "block text"},
|
||||||
|
]
|
||||||
|
assert _extract_text(content) == "prefix-continued\nblock text"
|
||||||
|
|
||||||
|
def test_list_skips_non_text_blocks(self):
|
||||||
|
content = [
|
||||||
|
{"type": "image_url", "image_url": {"url": "http://img.png"}},
|
||||||
|
{"type": "text", "text": "actual text"},
|
||||||
|
]
|
||||||
|
assert _extract_text(content) == "actual text"
|
||||||
|
|
||||||
|
def test_empty_list(self):
|
||||||
|
assert _extract_text([]) == ""
|
||||||
|
|
||||||
|
def test_list_no_text_blocks(self):
|
||||||
|
assert _extract_text([{"type": "image_url", "image_url": {}}]) == ""
|
||||||
|
|
||||||
|
def test_non_str_non_list(self):
|
||||||
|
assert _extract_text(42) == "42"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# format_conversation_for_update — handles mixed list content
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatConversationForUpdate:
|
||||||
|
def test_plain_string_messages(self):
|
||||||
|
human_msg = MagicMock()
|
||||||
|
human_msg.type = "human"
|
||||||
|
human_msg.content = "What is Python?"
|
||||||
|
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Python is a programming language."
|
||||||
|
|
||||||
|
result = format_conversation_for_update([human_msg, ai_msg])
|
||||||
|
assert "User: What is Python?" in result
|
||||||
|
assert "Assistant: Python is a programming language." in result
|
||||||
|
|
||||||
|
def test_list_content_with_plain_strings(self):
|
||||||
|
"""Plain strings in list content should not be lost."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = ["raw user text", {"type": "text", "text": "structured text"}]
|
||||||
|
|
||||||
|
result = format_conversation_for_update([msg])
|
||||||
|
assert "raw user text" in result
|
||||||
|
assert "structured text" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# update_memory — structured LLM response handling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateMemoryStructuredResponse:
|
||||||
|
"""update_memory should handle LLM responses returned as list content blocks."""
|
||||||
|
|
||||||
|
def _make_mock_model(self, content):
|
||||||
|
model = MagicMock()
|
||||||
|
response = MagicMock()
|
||||||
|
response.content = content
|
||||||
|
model.invoke.return_value = response
|
||||||
|
return model
|
||||||
|
|
||||||
|
def test_string_response_parses(self):
|
||||||
|
updater = MemoryUpdater()
|
||||||
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||||
|
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||||
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi there"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
result = updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_list_content_response_parses(self):
|
||||||
|
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
||||||
|
updater = MemoryUpdater()
|
||||||
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
list_content = [{"type": "text", "text": valid_json}]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(updater, "_get_model", return_value=self._make_mock_model(list_content)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||||
|
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
|
||||||
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
result = updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|||||||
129
backend/tests/test_serialize_message_content.py
Normal file
129
backend/tests/test_serialize_message_content.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""Regression tests for ToolMessage content normalization in serialization.
|
||||||
|
|
||||||
|
Ensures that structured content (list-of-blocks) is properly extracted to
|
||||||
|
plain text, preventing raw Python repr strings from reaching the UI.
|
||||||
|
|
||||||
|
See: https://github.com/bytedance/deer-flow/issues/1149
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
|
from deerflow.client import DeerFlowClient
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _serialize_message
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerializeToolMessageContent:
|
||||||
|
"""DeerFlowClient._serialize_message should normalize ToolMessage content."""
|
||||||
|
|
||||||
|
def test_string_content(self):
|
||||||
|
msg = ToolMessage(content="ok", tool_call_id="tc1", name="search")
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
assert result["content"] == "ok"
|
||||||
|
assert result["type"] == "tool"
|
||||||
|
|
||||||
|
def test_list_of_blocks_content(self):
|
||||||
|
"""List-of-blocks should be extracted, not repr'd."""
|
||||||
|
msg = ToolMessage(
|
||||||
|
content=[{"type": "text", "text": "hello world"}],
|
||||||
|
tool_call_id="tc1",
|
||||||
|
name="search",
|
||||||
|
)
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
assert result["content"] == "hello world"
|
||||||
|
# Must NOT contain Python repr artifacts
|
||||||
|
assert "[" not in result["content"]
|
||||||
|
assert "{" not in result["content"]
|
||||||
|
|
||||||
|
def test_multiple_text_blocks(self):
|
||||||
|
"""Multiple full text blocks should be joined with newlines."""
|
||||||
|
msg = ToolMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "line 1"},
|
||||||
|
{"type": "text", "text": "line 2"},
|
||||||
|
],
|
||||||
|
tool_call_id="tc1",
|
||||||
|
name="search",
|
||||||
|
)
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
assert result["content"] == "line 1\nline 2"
|
||||||
|
|
||||||
|
def test_string_chunks_are_joined_without_newlines(self):
|
||||||
|
"""Chunked string payloads should not get artificial separators."""
|
||||||
|
msg = ToolMessage(
|
||||||
|
content=["{\"a\"", ": \"b\"}"] ,
|
||||||
|
tool_call_id="tc1",
|
||||||
|
name="search",
|
||||||
|
)
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
assert result["content"] == '{"a": "b"}'
|
||||||
|
|
||||||
|
def test_mixed_string_chunks_and_blocks(self):
|
||||||
|
"""String chunks stay contiguous, but text blocks remain separated."""
|
||||||
|
msg = ToolMessage(
|
||||||
|
content=["prefix", "-continued", {"type": "text", "text": "block text"}],
|
||||||
|
tool_call_id="tc1",
|
||||||
|
name="search",
|
||||||
|
)
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
assert result["content"] == "prefix-continued\nblock text"
|
||||||
|
|
||||||
|
def test_mixed_blocks_with_non_text(self):
|
||||||
|
"""Non-text blocks (e.g. image) should be skipped gracefully."""
|
||||||
|
msg = ToolMessage(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "found results"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "http://img.png"}},
|
||||||
|
],
|
||||||
|
tool_call_id="tc1",
|
||||||
|
name="view_image",
|
||||||
|
)
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
assert result["content"] == "found results"
|
||||||
|
|
||||||
|
def test_empty_list_content(self):
|
||||||
|
msg = ToolMessage(content=[], tool_call_id="tc1", name="search")
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
assert result["content"] == ""
|
||||||
|
|
||||||
|
def test_plain_string_in_list(self):
|
||||||
|
"""Bare strings inside a list should be kept."""
|
||||||
|
msg = ToolMessage(
|
||||||
|
content=["plain text block"],
|
||||||
|
tool_call_id="tc1",
|
||||||
|
name="search",
|
||||||
|
)
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
assert result["content"] == "plain text block"
|
||||||
|
|
||||||
|
def test_unknown_content_type_falls_back(self):
|
||||||
|
"""Unexpected types should not crash — return str()."""
|
||||||
|
msg = ToolMessage(content=42, tool_call_id="tc1", name="calc")
|
||||||
|
result = DeerFlowClient._serialize_message(msg)
|
||||||
|
# int → not str, not list → falls to str()
|
||||||
|
assert result["content"] == "42"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_text (already existed, but verify it also covers ToolMessage paths)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractText:
|
||||||
|
"""DeerFlowClient._extract_text should handle all content shapes."""
|
||||||
|
|
||||||
|
def test_string_passthrough(self):
|
||||||
|
assert DeerFlowClient._extract_text("hello") == "hello"
|
||||||
|
|
||||||
|
def test_list_text_blocks(self):
|
||||||
|
assert DeerFlowClient._extract_text(
|
||||||
|
[{"type": "text", "text": "hi"}]
|
||||||
|
) == "hi"
|
||||||
|
|
||||||
|
def test_empty_list(self):
|
||||||
|
assert DeerFlowClient._extract_text([]) == ""
|
||||||
|
|
||||||
|
def test_fallback_non_iterable(self):
|
||||||
|
assert DeerFlowClient._extract_text(123) == "123"
|
||||||
Reference in New Issue
Block a user