mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-23 22:24:46 +08:00
feat: add view_image tool and optimize web fetch tools
Add image viewing capability for vision-enabled models with ViewImageMiddleware and view_image_tool. Limit web_fetch tool output to 4096 characters to prevent excessive content. Update model config to support vision capability flag. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from src.agents.middlewares.clarification_middleware import ClarificationMiddlew
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from src.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from src.agents.thread_state import ThreadState
|
||||
from src.config.summarization_config import get_summarization_config
|
||||
from src.models import create_chat_model
|
||||
@@ -174,6 +175,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
# UploadsMiddleware should be after ThreadDataMiddleware to access thread_id
|
||||
# SummarizationMiddleware should be early to reduce context before other processing
|
||||
# TodoListMiddleware should be before ClarificationMiddleware to allow todo management
|
||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||
def _build_middlewares(config: RunnableConfig):
|
||||
"""Build middleware chain based on runtime configuration.
|
||||
@@ -197,7 +199,24 @@ def _build_middlewares(config: RunnableConfig):
|
||||
if todo_list_middleware is not None:
|
||||
middlewares.append(todo_list_middleware)
|
||||
|
||||
middlewares.extend([TitleMiddleware(), ClarificationMiddleware()])
|
||||
# Add TitleMiddleware
|
||||
middlewares.append(TitleMiddleware())
|
||||
|
||||
# Add ViewImageMiddleware only if the current model supports vision
|
||||
model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
|
||||
from src.config import get_app_config
|
||||
|
||||
app_config = get_app_config()
|
||||
# If no model_name specified, use the first model (default)
|
||||
if model_name is None and app_config.models:
|
||||
model_name = app_config.models[0].name
|
||||
|
||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||
if model_config is not None and model_config.supports_vision:
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
|
||||
|
||||
221
backend/src/agents/middlewares/view_image_middleware.py
Normal file
221
backend/src/agents/middlewares/view_image_middleware.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Middleware for injecting image details into conversation before LLM call."""
|
||||
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from src.agents.thread_state import ViewedImageData
|
||||
|
||||
|
||||
class ViewImageMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
viewed_images: NotRequired[dict[str, ViewedImageData] | None]
|
||||
|
||||
|
||||
class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
|
||||
"""Injects image details as a human message before LLM calls when view_image tools have completed.
|
||||
|
||||
This middleware:
|
||||
1. Runs before each LLM call
|
||||
2. Checks if the last assistant message contains view_image tool calls
|
||||
3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages)
|
||||
4. If conditions are met, creates a human message with all viewed image details (including base64 data)
|
||||
5. Adds the message to state so the LLM can see and analyze the images
|
||||
|
||||
This enables the LLM to automatically receive and analyze images that were loaded via view_image tool,
|
||||
without requiring explicit user prompts to describe the images.
|
||||
"""
|
||||
|
||||
state_schema = ViewImageMiddlewareState
|
||||
|
||||
def _get_last_assistant_message(self, messages: list) -> AIMessage | None:
|
||||
"""Get the last assistant message from the message list.
|
||||
|
||||
Args:
|
||||
messages: List of messages
|
||||
|
||||
Returns:
|
||||
Last AIMessage or None if not found
|
||||
"""
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
return msg
|
||||
return None
|
||||
|
||||
def _has_view_image_tool(self, message: AIMessage) -> bool:
|
||||
"""Check if the assistant message contains view_image tool calls.
|
||||
|
||||
Args:
|
||||
message: Assistant message to check
|
||||
|
||||
Returns:
|
||||
True if message contains view_image tool calls
|
||||
"""
|
||||
if not hasattr(message, "tool_calls") or not message.tool_calls:
|
||||
return False
|
||||
|
||||
return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls)
|
||||
|
||||
def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool:
|
||||
"""Check if all tool calls in the assistant message have been completed.
|
||||
|
||||
Args:
|
||||
messages: List of all messages
|
||||
assistant_msg: The assistant message containing tool calls
|
||||
|
||||
Returns:
|
||||
True if all tool calls have corresponding ToolMessages
|
||||
"""
|
||||
if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls:
|
||||
return False
|
||||
|
||||
# Get all tool call IDs from the assistant message
|
||||
tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")}
|
||||
|
||||
# Find the index of the assistant message
|
||||
try:
|
||||
assistant_idx = messages.index(assistant_msg)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Get all ToolMessages after the assistant message
|
||||
completed_tool_ids = set()
|
||||
for msg in messages[assistant_idx + 1 :]:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id:
|
||||
completed_tool_ids.add(msg.tool_call_id)
|
||||
|
||||
# Check if all tool calls have been completed
|
||||
return tool_call_ids.issubset(completed_tool_ids)
|
||||
|
||||
def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]:
|
||||
"""Create a formatted message with all viewed image details.
|
||||
|
||||
Args:
|
||||
state: Current state containing viewed_images
|
||||
|
||||
Returns:
|
||||
List of content blocks (text and images) for the HumanMessage
|
||||
"""
|
||||
viewed_images = state.get("viewed_images", {})
|
||||
if not viewed_images:
|
||||
return ["No images have been viewed."]
|
||||
|
||||
# Build the message with image information
|
||||
content_blocks: list[str | dict] = [{"type": "text", "text": "Here are the images you've viewed:"}]
|
||||
|
||||
for image_path, image_data in viewed_images.items():
|
||||
mime_type = image_data.get("mime_type", "unknown")
|
||||
base64_data = image_data.get("base64", "")
|
||||
|
||||
# Add text description
|
||||
content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"})
|
||||
|
||||
# Add the actual image data so LLM can "see" it
|
||||
if base64_data:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{mime_type};base64,{base64_data}"},
|
||||
}
|
||||
)
|
||||
|
||||
return content_blocks
|
||||
|
||||
def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool:
|
||||
"""Determine if we should inject an image details message.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
|
||||
Returns:
|
||||
True if we should inject the message
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
# Get the last assistant message
|
||||
last_assistant_msg = self._get_last_assistant_message(messages)
|
||||
if not last_assistant_msg:
|
||||
return False
|
||||
|
||||
# Check if it has view_image tool calls
|
||||
if not self._has_view_image_tool(last_assistant_msg):
|
||||
return False
|
||||
|
||||
# Check if all tools have been completed
|
||||
if not self._all_tools_completed(messages, last_assistant_msg):
|
||||
return False
|
||||
|
||||
# Check if we've already added an image details message
|
||||
# Look for a human message after the last assistant message that contains image details
|
||||
assistant_idx = messages.index(last_assistant_msg)
|
||||
for msg in messages[assistant_idx + 1 :]:
|
||||
if isinstance(msg, HumanMessage):
|
||||
content_str = str(msg.content)
|
||||
if "Here are the images you've viewed" in content_str or "Here are the details of the images you've viewed" in content_str:
|
||||
# Already added, don't add again
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None:
|
||||
"""Internal helper to inject image details message.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
if not self._should_inject_image_message(state):
|
||||
return None
|
||||
|
||||
# Create the image details message with text and image content
|
||||
image_content = self._create_image_details_message(state)
|
||||
|
||||
# Create a new human message with mixed content (text + images)
|
||||
human_msg = HumanMessage(content=image_content)
|
||||
|
||||
print("[ViewImageMiddleware] Injecting image details message with images before LLM call")
|
||||
|
||||
# Return state update with the new message
|
||||
return {"messages": [human_msg]}
|
||||
|
||||
@override
|
||||
def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Inject image details message before LLM call if view_image tools have completed (sync version).
|
||||
|
||||
This runs before each LLM call, checking if the previous turn included view_image
|
||||
tool calls that have all completed. If so, it injects a human message with the image
|
||||
details so the LLM can see and analyze the images.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
runtime: Runtime context (unused but required by interface)
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
return self._inject_image_message(state)
|
||||
|
||||
@override
|
||||
async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Inject image details message before LLM call if view_image tools have completed (async version).
|
||||
|
||||
This runs before each LLM call, checking if the previous turn included view_image
|
||||
tool calls that have all completed. If so, it injects a human message with the image
|
||||
details so the LLM can see and analyze the images.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
runtime: Runtime context (unused but required by interface)
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
return self._inject_image_message(state)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import NotRequired, TypedDict
|
||||
from typing import Annotated, NotRequired, TypedDict
|
||||
|
||||
from langchain.agents import AgentState
|
||||
|
||||
@@ -13,10 +13,43 @@ class ThreadDataState(TypedDict):
|
||||
outputs_path: NotRequired[str | None]
|
||||
|
||||
|
||||
class ViewedImageData(TypedDict):
|
||||
base64: str
|
||||
mime_type: str
|
||||
|
||||
|
||||
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
|
||||
"""Reducer for artifacts list - merges and deduplicates artifacts."""
|
||||
if existing is None:
|
||||
return new or []
|
||||
if new is None:
|
||||
return existing
|
||||
# Use dict.fromkeys to deduplicate while preserving order
|
||||
return list(dict.fromkeys(existing + new))
|
||||
|
||||
|
||||
def merge_viewed_images(existing: dict[str, ViewedImageData] | None, new: dict[str, ViewedImageData] | None) -> dict[str, ViewedImageData]:
|
||||
"""Reducer for viewed_images dict - merges image dictionaries.
|
||||
|
||||
Special case: If new is an empty dict {}, it clears the existing images.
|
||||
This allows middlewares to clear the viewed_images state after processing.
|
||||
"""
|
||||
if existing is None:
|
||||
return new or {}
|
||||
if new is None:
|
||||
return existing
|
||||
# Special case: empty dict means clear all viewed images
|
||||
if len(new) == 0:
|
||||
return {}
|
||||
# Merge dictionaries, new values override existing ones for same keys
|
||||
return {**existing, **new}
|
||||
|
||||
|
||||
class ThreadState(AgentState):
|
||||
sandbox: NotRequired[SandboxState | None]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
title: NotRequired[str | None]
|
||||
artifacts: NotRequired[list[str] | None]
|
||||
artifacts: Annotated[list[str], merge_artifacts]
|
||||
todos: NotRequired[list | None]
|
||||
uploaded_files: NotRequired[list[dict] | None]
|
||||
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}
|
||||
|
||||
Reference in New Issue
Block a user