mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
feat: optimize vision tools and image handling
- Add model-aware vision tool loading based on supports_vision flag
- Move view_image_tool from config to builtin tools for dynamic inclusion
- Add timeout to image search to prevent hanging requests
- Optimize image search results format using thumbnails
- Add image validation for reference images in generation
- Improve error handling with detailed messages

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -230,7 +230,7 @@ def make_lead_agent(config: RunnableConfig):
|
|||||||
print(f"thinking_enabled: {thinking_enabled}, model_name: {model_name}, is_plan_mode: {is_plan_mode}")
|
print(f"thinking_enabled: {thinking_enabled}, model_name: {model_name}, is_plan_mode: {is_plan_mode}")
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||||
tools=get_available_tools(),
|
tools=get_available_tools(model_name=model_name),
|
||||||
middleware=_build_middlewares(config),
|
middleware=_build_middlewares(config),
|
||||||
system_prompt=apply_prompt_template(),
|
system_prompt=apply_prompt_template(),
|
||||||
state_schema=ThreadState,
|
state_schema=ThreadState,
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def _search_images(
|
|||||||
logger.error("ddgs library not installed. Run: pip install ddgs")
|
logger.error("ddgs library not installed. Run: pip install ddgs")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
ddgs = DDGS()
|
ddgs = DDGS(timeout=30)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@@ -119,12 +119,8 @@ def image_search_tool(
|
|||||||
normalized_results = [
|
normalized_results = [
|
||||||
{
|
{
|
||||||
"title": r.get("title", ""),
|
"title": r.get("title", ""),
|
||||||
"image_url": r.get("image", ""),
|
"image_url": r.get("thumbnail", ""),
|
||||||
"thumbnail_url": r.get("thumbnail", ""),
|
"thumbnail_url": r.get("thumbnail", ""),
|
||||||
"source_url": r.get("url", ""),
|
|
||||||
"source": r.get("source", ""),
|
|
||||||
"width": r.get("width"),
|
|
||||||
"height": r.get("height"),
|
|
||||||
}
|
}
|
||||||
for r in results
|
for r in results
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
|
|||||||
|
|
||||||
from src.config import get_app_config
|
from src.config import get_app_config
|
||||||
from src.reflection import resolve_variable
|
from src.reflection import resolve_variable
|
||||||
from src.tools.builtins import ask_clarification_tool, present_file_tool
|
from src.tools.builtins import ask_clarification_tool, present_file_tool, view_image_tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -14,7 +14,7 @@ BUILTIN_TOOLS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_available_tools(groups: list[str] | None = None, include_mcp: bool = True) -> list[BaseTool]:
|
def get_available_tools(groups: list[str] | None = None, include_mcp: bool = True, model_name: str | None = None) -> list[BaseTool]:
|
||||||
"""Get all available tools from config.
|
"""Get all available tools from config.
|
||||||
|
|
||||||
Note: MCP tools should be initialized at application startup using
|
Note: MCP tools should be initialized at application startup using
|
||||||
@@ -23,6 +23,7 @@ def get_available_tools(groups: list[str] | None = None, include_mcp: bool = Tru
|
|||||||
Args:
|
Args:
|
||||||
groups: Optional list of tool groups to filter by.
|
groups: Optional list of tool groups to filter by.
|
||||||
include_mcp: Whether to include tools from MCP servers (default: True).
|
include_mcp: Whether to include tools from MCP servers (default: True).
|
||||||
|
model_name: Optional model name to determine if vision tools should be included.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of available tools.
|
List of available tools.
|
||||||
@@ -51,4 +52,16 @@ def get_available_tools(groups: list[str] | None = None, include_mcp: bool = Tru
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get cached MCP tools: {e}")
|
logger.error(f"Failed to get cached MCP tools: {e}")
|
||||||
|
|
||||||
return loaded_tools + BUILTIN_TOOLS + mcp_tools
|
# Conditionally add view_image_tool only if the model supports vision
|
||||||
|
builtin_tools = BUILTIN_TOOLS.copy()
|
||||||
|
|
||||||
|
# If no model_name specified, use the first model (default)
|
||||||
|
if model_name is None and config.models:
|
||||||
|
model_name = config.models[0].name
|
||||||
|
|
||||||
|
model_config = config.get_model_config(model_name) if model_name else None
|
||||||
|
if model_config is not None and model_config.supports_vision:
|
||||||
|
builtin_tools.append(view_image_tool)
|
||||||
|
logger.info(f"Including view_image_tool for model '{model_name}' (supports_vision=True)")
|
||||||
|
|
||||||
|
return loaded_tools + builtin_tools + mcp_tools
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ models:
|
|||||||
# api_key: $MOONSHOT_API_KEY
|
# api_key: $MOONSHOT_API_KEY
|
||||||
# max_tokens: 32768
|
# max_tokens: 32768
|
||||||
# supports_thinking: true
|
# supports_thinking: true
|
||||||
# supports_vision: false # Check your specific model's capabilities
|
# supports_vision: true # Check your specific model's capabilities
|
||||||
# when_thinking_enabled:
|
# when_thinking_enabled:
|
||||||
# extra_body:
|
# extra_body:
|
||||||
# thinking:
|
# thinking:
|
||||||
@@ -112,11 +112,6 @@ tools:
|
|||||||
use: src.community.image_search.tools:image_search_tool
|
use: src.community.image_search.tools:image_search_tool
|
||||||
max_results: 5
|
max_results: 5
|
||||||
|
|
||||||
# View image tool (display local images to user)
|
|
||||||
- name: view_image
|
|
||||||
group: file:read
|
|
||||||
use: src.tools.builtins:view_image_tool
|
|
||||||
|
|
||||||
# File operations tools
|
# File operations tools
|
||||||
- name: ls
|
- name: ls
|
||||||
group: file:read
|
group: file:read
|
||||||
|
|||||||
@@ -2,6 +2,29 @@ import base64
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def validate_image(image_path: str) -> bool:
    """Report whether an image file opens cleanly and its data can be loaded.

    The file is opened twice: once for ``verify()`` and once more, on a fresh
    handle, for ``load()``, because ``verify()`` alone may not catch every
    problem.

    Args:
        image_path: Path to the image file.

    Returns:
        True if both passes complete without raising, False otherwise. On
        failure a warning naming the path and the error is printed.
    """
    try:
        # Pass 1: integrity check on the file.
        with Image.open(image_path) as probe:
            probe.verify()
        # Pass 2: re-open and force the image data to be loaded.
        with Image.open(image_path) as decoded:
            decoded.load()
    except Exception as e:
        print(f"Warning: Image '{image_path}' is invalid or corrupted: {e}")
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
def generate_image(
|
def generate_image(
|
||||||
@@ -14,7 +37,19 @@ def generate_image(
|
|||||||
prompt = f.read()
|
prompt = f.read()
|
||||||
parts = []
|
parts = []
|
||||||
i = 0
|
i = 0
|
||||||
for reference_image in reference_images:
|
|
||||||
|
# Filter out invalid reference images
|
||||||
|
valid_reference_images = []
|
||||||
|
for ref_img in reference_images:
|
||||||
|
if validate_image(ref_img):
|
||||||
|
valid_reference_images.append(ref_img)
|
||||||
|
else:
|
||||||
|
print(f"Skipping invalid reference image: {ref_img}")
|
||||||
|
|
||||||
|
if len(valid_reference_images) < len(reference_images):
|
||||||
|
print(f"Note: {len(reference_images) - len(valid_reference_images)} reference image(s) were skipped due to validation failure.")
|
||||||
|
|
||||||
|
for reference_image in valid_reference_images:
|
||||||
i += 1
|
i += 1
|
||||||
with open(reference_image, "rb") as f:
|
with open(reference_image, "rb") as f:
|
||||||
image_b64 = base64.b64encode(f.read()).decode("utf-8")
|
image_b64 = base64.b64encode(f.read()).decode("utf-8")
|
||||||
@@ -41,6 +76,7 @@ def generate_image(
|
|||||||
"contents": [{"parts": [*parts, {"text": prompt}]}],
|
"contents": [{"parts": [*parts, {"text": prompt}]}],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
json = response.json()
|
json = response.json()
|
||||||
parts: list[dict] = json["candidates"][0]["content"]["parts"]
|
parts: list[dict] = json["candidates"][0]["content"]["parts"]
|
||||||
image_parts = [part for part in parts if part.get("inlineData", False)]
|
image_parts = [part for part in parts if part.get("inlineData", False)]
|
||||||
@@ -92,5 +128,5 @@ if __name__ == "__main__":
|
|||||||
args.aspect_ratio,
|
args.aspect_ratio,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
print("Error while generating image.")
|
print(f"Error while generating image: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user