mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
feat: add view_image tool and optimize web fetch tools
Add image viewing capability for vision-enabled models with ViewImageMiddleware and view_image_tool. Limit web_fetch tool output to 4096 characters to prevent excessive content. Update model config to support vision capability flag. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from src.agents.middlewares.clarification_middleware import ClarificationMiddlew
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from src.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from src.agents.thread_state import ThreadState
|
||||
from src.config.summarization_config import get_summarization_config
|
||||
from src.models import create_chat_model
|
||||
@@ -174,6 +175,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
# UploadsMiddleware should be after ThreadDataMiddleware to access thread_id
|
||||
# SummarizationMiddleware should be early to reduce context before other processing
|
||||
# TodoListMiddleware should be before ClarificationMiddleware to allow todo management
|
||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||
def _build_middlewares(config: RunnableConfig):
|
||||
"""Build middleware chain based on runtime configuration.
|
||||
@@ -197,7 +199,24 @@ def _build_middlewares(config: RunnableConfig):
|
||||
if todo_list_middleware is not None:
|
||||
middlewares.append(todo_list_middleware)
|
||||
|
||||
middlewares.extend([TitleMiddleware(), ClarificationMiddleware()])
|
||||
# Add TitleMiddleware
|
||||
middlewares.append(TitleMiddleware())
|
||||
|
||||
# Add ViewImageMiddleware only if the current model supports vision
|
||||
model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
|
||||
from src.config import get_app_config
|
||||
|
||||
app_config = get_app_config()
|
||||
# If no model_name specified, use the first model (default)
|
||||
if model_name is None and app_config.models:
|
||||
model_name = app_config.models[0].name
|
||||
|
||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||
if model_config is not None and model_config.supports_vision:
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
|
||||
|
||||
221
backend/src/agents/middlewares/view_image_middleware.py
Normal file
221
backend/src/agents/middlewares/view_image_middleware.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Middleware for injecting image details into conversation before LLM call."""

import logging
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime

from src.agents.thread_state import ViewedImageData

logger = logging.getLogger(__name__)

# Leading text of the injected human message. Also used to detect a prior
# injection so the same images are not re-sent on subsequent model calls.
_INJECTED_MESSAGE_HEADER = "Here are the images you've viewed"


class ViewImageMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    # image_path -> {"base64": ..., "mime_type": ...}; written by the
    # view_image tool and merged by the merge_viewed_images reducer.
    viewed_images: NotRequired[dict[str, ViewedImageData] | None]


class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
    """Injects image details as a human message before LLM calls when view_image tools have completed.

    This middleware:
    1. Runs before each LLM call
    2. Checks if the last assistant message contains view_image tool calls
    3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages)
    4. If conditions are met, creates a human message with all viewed image details (including base64 data)
    5. Adds the message to state so the LLM can see and analyze the images

    This enables the LLM to automatically receive and analyze images that were loaded via view_image tool,
    without requiring explicit user prompts to describe the images.
    """

    state_schema = ViewImageMiddlewareState

    def _get_last_assistant_message(self, messages: list) -> AIMessage | None:
        """Get the last assistant message from the message list.

        Args:
            messages: List of messages

        Returns:
            Last AIMessage or None if not found
        """
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                return msg
        return None

    def _has_view_image_tool(self, message: AIMessage) -> bool:
        """Check if the assistant message contains view_image tool calls.

        Args:
            message: Assistant message to check

        Returns:
            True if message contains view_image tool calls
        """
        if not hasattr(message, "tool_calls") or not message.tool_calls:
            return False

        return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls)

    def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool:
        """Check if all tool calls in the assistant message have been completed.

        Args:
            messages: List of all messages
            assistant_msg: The assistant message containing tool calls

        Returns:
            True if all tool calls have corresponding ToolMessages
        """
        if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls:
            return False

        # Get all tool call IDs from the assistant message
        tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")}

        # Find the index of the assistant message
        try:
            assistant_idx = messages.index(assistant_msg)
        except ValueError:
            return False

        # Get all ToolMessages after the assistant message
        completed_tool_ids = set()
        for msg in messages[assistant_idx + 1 :]:
            if isinstance(msg, ToolMessage) and msg.tool_call_id:
                completed_tool_ids.add(msg.tool_call_id)

        # Check if all tool calls have been completed
        return tool_call_ids.issubset(completed_tool_ids)

    def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]:
        """Create a formatted message with all viewed image details.

        Args:
            state: Current state containing viewed_images

        Returns:
            List of content blocks (text and images) for the HumanMessage
        """
        viewed_images = state.get("viewed_images", {})
        if not viewed_images:
            # Defensive fallback; _should_inject_image_message already guards
            # against an empty viewed_images dict.
            return ["No images have been viewed."]

        # Build the message with image information
        content_blocks: list[str | dict] = [{"type": "text", "text": f"{_INJECTED_MESSAGE_HEADER}:"}]

        for image_path, image_data in viewed_images.items():
            mime_type = image_data.get("mime_type", "unknown")
            base64_data = image_data.get("base64", "")

            # Add text description
            content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"})

            # Add the actual image data so LLM can "see" it
            if base64_data:
                content_blocks.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{base64_data}"},
                    }
                )

        return content_blocks

    def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool:
        """Determine if we should inject an image details message.

        Args:
            state: Current state

        Returns:
            True if we should inject the message
        """
        messages = state.get("messages", [])
        if not messages:
            return False

        # Nothing to inject if no images were actually loaded (e.g. the
        # view_image tool call failed before populating viewed_images).
        if not state.get("viewed_images"):
            return False

        # Get the last assistant message
        last_assistant_msg = self._get_last_assistant_message(messages)
        if not last_assistant_msg:
            return False

        # Check if it has view_image tool calls
        if not self._has_view_image_tool(last_assistant_msg):
            return False

        # Check if all tools have been completed
        if not self._all_tools_completed(messages, last_assistant_msg):
            return False

        # Check if we've already added an image details message
        # Look for a human message after the last assistant message that contains image details
        assistant_idx = messages.index(last_assistant_msg)
        for msg in messages[assistant_idx + 1 :]:
            if isinstance(msg, HumanMessage):
                content_str = str(msg.content)
                if _INJECTED_MESSAGE_HEADER in content_str or "Here are the details of the images you've viewed" in content_str:
                    # Already added, don't add again
                    return False

        return True

    def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None:
        """Internal helper to inject image details message.

        Args:
            state: Current state

        Returns:
            State update with additional human message, or None if no update needed
        """
        if not self._should_inject_image_message(state):
            return None

        # Create the image details message with text and image content
        image_content = self._create_image_details_message(state)

        # Create a new human message with mixed content (text + images)
        human_msg = HumanMessage(content=image_content)

        logger.debug("Injecting image details message with images before LLM call")

        # Return state update with the new message
        return {"messages": [human_msg]}

    @override
    def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
        """Inject image details message before LLM call if view_image tools have completed (sync version).

        This runs before each LLM call, checking if the previous turn included view_image
        tool calls that have all completed. If so, it injects a human message with the image
        details so the LLM can see and analyze the images.

        Args:
            state: Current state
            runtime: Runtime context (unused but required by interface)

        Returns:
            State update with additional human message, or None if no update needed
        """
        return self._inject_image_message(state)

    @override
    async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
        """Inject image details message before LLM call if view_image tools have completed (async version).

        This runs before each LLM call, checking if the previous turn included view_image
        tool calls that have all completed. If so, it injects a human message with the image
        details so the LLM can see and analyze the images.

        Args:
            state: Current state
            runtime: Runtime context (unused but required by interface)

        Returns:
            State update with additional human message, or None if no update needed
        """
        return self._inject_image_message(state)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import NotRequired, TypedDict
|
||||
from typing import Annotated, NotRequired, TypedDict
|
||||
|
||||
from langchain.agents import AgentState
|
||||
|
||||
@@ -13,10 +13,43 @@ class ThreadDataState(TypedDict):
|
||||
outputs_path: NotRequired[str | None]
|
||||
|
||||
|
||||
class ViewedImageData(TypedDict):
|
||||
base64: str
|
||||
mime_type: str
|
||||
|
||||
|
||||
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
|
||||
"""Reducer for artifacts list - merges and deduplicates artifacts."""
|
||||
if existing is None:
|
||||
return new or []
|
||||
if new is None:
|
||||
return existing
|
||||
# Use dict.fromkeys to deduplicate while preserving order
|
||||
return list(dict.fromkeys(existing + new))
|
||||
|
||||
|
||||
def merge_viewed_images(existing: dict[str, ViewedImageData] | None, new: dict[str, ViewedImageData] | None) -> dict[str, ViewedImageData]:
|
||||
"""Reducer for viewed_images dict - merges image dictionaries.
|
||||
|
||||
Special case: If new is an empty dict {}, it clears the existing images.
|
||||
This allows middlewares to clear the viewed_images state after processing.
|
||||
"""
|
||||
if existing is None:
|
||||
return new or {}
|
||||
if new is None:
|
||||
return existing
|
||||
# Special case: empty dict means clear all viewed images
|
||||
if len(new) == 0:
|
||||
return {}
|
||||
# Merge dictionaries, new values override existing ones for same keys
|
||||
return {**existing, **new}
|
||||
|
||||
|
||||
class ThreadState(AgentState):
|
||||
sandbox: NotRequired[SandboxState | None]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
title: NotRequired[str | None]
|
||||
artifacts: NotRequired[list[str] | None]
|
||||
artifacts: Annotated[list[str], merge_artifacts]
|
||||
todos: NotRequired[list | None]
|
||||
uploaded_files: NotRequired[list[dict] | None]
|
||||
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}
|
||||
|
||||
@@ -70,4 +70,4 @@ def web_fetch_tool(url: str) -> str:
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
return f"# {title}\n\n{markdown_content}"
|
||||
return f"# {title}\n\n{markdown_content[:4096]}"
|
||||
|
||||
@@ -25,4 +25,4 @@ def web_fetch_tool(url: str) -> str:
|
||||
timeout = config.model_extra.get("timeout")
|
||||
html_content = jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||
article = readability_extractor.extract_article(html_content)
|
||||
return article.to_markdown()
|
||||
return article.to_markdown()[:4096]
|
||||
|
||||
@@ -57,6 +57,6 @@ def web_fetch_tool(url: str) -> str:
|
||||
return f"Error: {res['failed_results'][0]['error']}"
|
||||
elif "results" in res and len(res["results"]) > 0:
|
||||
result = res["results"][0]
|
||||
return f"# {result['title']}\n\n{result['raw_content']}"
|
||||
return f"# {result['title']}\n\n{result['raw_content'][:4096]}"
|
||||
else:
|
||||
return "Error: No results found"
|
||||
|
||||
@@ -18,3 +18,4 @@ class ModelConfig(BaseModel):
|
||||
default_factory=lambda: None,
|
||||
description="Extra settings to be passed to the model when thinking is enabled",
|
||||
)
|
||||
supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs")
|
||||
|
||||
@@ -29,6 +29,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
"description",
|
||||
"supports_thinking",
|
||||
"when_thinking_enabled",
|
||||
"supports_vision",
|
||||
},
|
||||
)
|
||||
if thinking_enabled and model_config.when_thinking_enabled is not None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .clarification_tool import ask_clarification_tool
|
||||
from .present_file_tool import present_file_tool
|
||||
from .view_image_tool import view_image_tool
|
||||
|
||||
__all__ = ["present_file_tool", "ask_clarification_tool"]
|
||||
__all__ = ["present_file_tool", "ask_clarification_tool", "view_image_tool"]
|
||||
|
||||
@@ -28,15 +28,12 @@ def present_file_tool(
|
||||
|
||||
Notes:
|
||||
- You should call this tool after creating files and moving them to the `/mnt/user-data/outputs` directory.
|
||||
- IMPORTANT: Do NOT call this tool in parallel with other tools. Call it separately.
|
||||
- This tool can be safely called in parallel with other tools. State updates are handled by a reducer to prevent conflicts.
|
||||
|
||||
Args:
|
||||
filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented.
|
||||
"""
|
||||
existing_artifacts = runtime.state.get("artifacts") or []
|
||||
# Use dict.fromkeys to deduplicate while preserving order
|
||||
new_artifacts = list(dict.fromkeys(existing_artifacts + filepaths))
|
||||
runtime.state["artifacts"] = new_artifacts
|
||||
# The merge_artifacts reducer will handle merging and deduplication
|
||||
return Command(
|
||||
update={"artifacts": new_artifacts, "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)]},
|
||||
update={"artifacts": filepaths, "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)]},
|
||||
)
|
||||
|
||||
94
backend/src/tools/builtins/view_image_tool.py
Normal file
94
backend/src/tools/builtins/view_image_tool.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import base64
import mimetypes
from pathlib import Path
from typing import Annotated

from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT

from src.agents.thread_state import ThreadState
from src.sandbox.tools import get_thread_data, replace_virtual_path

# Supported image extensions, in a fixed order so error messages are deterministic.
_VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

# Fallback MIME types for when mimetypes cannot guess from the extension.
_EXTENSION_TO_MIME = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".webp": "image/webp",
}


def _error_command(message: str, tool_call_id: str) -> Command:
    """Build a Command that only reports an error ToolMessage back to the model."""
    return Command(
        update={"messages": [ToolMessage(message, tool_call_id=tool_call_id)]},
    )


@tool("view_image", parse_docstring=True)
def view_image_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    image_path: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
    """Read an image file.

    Use this tool to read an image file and make it available for display.

    When to use the view_image tool:
    - When you need to view an image file.

    When NOT to use the view_image tool:
    - For non-image files (use present_files instead)
    - For multiple files at once (use present_files instead)

    Args:
        image_path: Absolute path to the image file. Common formats supported: jpg, jpeg, png, webp.
    """
    # Replace virtual path with actual path.
    # /mnt/user-data/* paths are mapped to thread-specific directories.
    thread_data = get_thread_data(runtime)
    actual_path = replace_virtual_path(image_path, thread_data)

    # Validate the path: must be absolute, exist, and be a regular file.
    path = Path(actual_path)
    if not path.is_absolute():
        return _error_command(f"Error: Path must be absolute, got: {image_path}", tool_call_id)
    if not path.exists():
        return _error_command(f"Error: Image file not found: {image_path}", tool_call_id)
    if not path.is_file():
        return _error_command(f"Error: Path is not a file: {image_path}", tool_call_id)

    # Validate image extension (fixed tuple keeps the error message deterministic).
    suffix = path.suffix.lower()
    if suffix not in _VALID_EXTENSIONS:
        return _error_command(
            f"Error: Unsupported image format: {path.suffix}. Supported formats: {', '.join(_VALID_EXTENSIONS)}",
            tool_call_id,
        )

    # Detect MIME type from file extension, falling back to known image types.
    mime_type, _ = mimetypes.guess_type(actual_path)
    if mime_type is None:
        mime_type = _EXTENSION_TO_MIME.get(suffix, "application/octet-stream")

    # Read image file and convert to base64.
    try:
        image_base64 = base64.b64encode(path.read_bytes()).decode("utf-8")
    except Exception as e:
        return _error_command(f"Error reading image file: {str(e)}", tool_call_id)

    # Record the image under its original (virtual) path; the
    # merge_viewed_images reducer merges this entry into existing state.
    new_viewed_images = {image_path: {"base64": image_base64, "mime_type": mime_type}}

    return Command(
        update={"viewed_images": new_viewed_images, "messages": [ToolMessage("Successfully read image", tool_call_id=tool_call_id)]},
    )
|
||||
@@ -21,6 +21,7 @@ models:
|
||||
api_key: $OPENAI_API_KEY # Use environment variable
|
||||
max_tokens: 4096
|
||||
temperature: 0.7
|
||||
supports_vision: true # Enable vision support for view_image tool
|
||||
|
||||
# Example: Anthropic Claude model
|
||||
# - name: claude-3-5-sonnet
|
||||
@@ -29,6 +30,7 @@ models:
|
||||
# model: claude-3-5-sonnet-20241022
|
||||
# api_key: $ANTHROPIC_API_KEY
|
||||
# max_tokens: 8192
|
||||
# supports_vision: true # Enable vision support for view_image tool
|
||||
|
||||
# Example: DeepSeek model (with thinking support)
|
||||
# - name: deepseek-v3
|
||||
@@ -38,6 +40,7 @@ models:
|
||||
# api_key: $DEEPSEEK_API_KEY
|
||||
# max_tokens: 16384
|
||||
# supports_thinking: true
|
||||
# supports_vision: false # DeepSeek V3 does not support vision
|
||||
# when_thinking_enabled:
|
||||
# extra_body:
|
||||
# thinking:
|
||||
@@ -51,6 +54,7 @@ models:
|
||||
# api_base: https://ark.cn-beijing.volces.com/api/v3
|
||||
# api_key: $VOLCENGINE_API_KEY
|
||||
# supports_thinking: true
|
||||
# supports_vision: false # Check your specific model's capabilities
|
||||
# when_thinking_enabled:
|
||||
# extra_body:
|
||||
# thinking:
|
||||
@@ -65,6 +69,7 @@ models:
|
||||
# api_key: $MOONSHOT_API_KEY
|
||||
# max_tokens: 32768
|
||||
# supports_thinking: true
|
||||
# supports_vision: false # Check your specific model's capabilities
|
||||
# when_thinking_enabled:
|
||||
# extra_body:
|
||||
# thinking:
|
||||
@@ -107,6 +112,11 @@ tools:
|
||||
use: src.community.image_search.tools:image_search_tool
|
||||
max_results: 5
|
||||
|
||||
# View image tool (display local images to user)
|
||||
- name: view_image
|
||||
group: file:read
|
||||
use: src.tools.builtins:view_image_tool
|
||||
|
||||
# File operations tools
|
||||
- name: ls
|
||||
group: file:read
|
||||
|
||||
Reference in New Issue
Block a user