diff --git a/backend/src/agents/lead_agent/agent.py b/backend/src/agents/lead_agent/agent.py index 8c01e15..836cba1 100644 --- a/backend/src/agents/lead_agent/agent.py +++ b/backend/src/agents/lead_agent/agent.py @@ -7,6 +7,7 @@ from src.agents.middlewares.clarification_middleware import ClarificationMiddlew from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware from src.agents.middlewares.title_middleware import TitleMiddleware from src.agents.middlewares.uploads_middleware import UploadsMiddleware +from src.agents.middlewares.view_image_middleware import ViewImageMiddleware from src.agents.thread_state import ThreadState from src.config.summarization_config import get_summarization_config from src.models import create_chat_model @@ -174,6 +175,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r # UploadsMiddleware should be after ThreadDataMiddleware to access thread_id # SummarizationMiddleware should be early to reduce context before other processing # TodoListMiddleware should be before ClarificationMiddleware to allow todo management +# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM # ClarificationMiddleware should be last to intercept clarification requests after model calls def _build_middlewares(config: RunnableConfig): """Build middleware chain based on runtime configuration. @@ -197,7 +199,24 @@ def _build_middlewares(config: RunnableConfig): if todo_list_middleware is not None: middlewares.append(todo_list_middleware) - middlewares.extend([TitleMiddleware(), ClarificationMiddleware()]) + # Add TitleMiddleware + middlewares.append(TitleMiddleware()) + + # Add ViewImageMiddleware only if the current model supports vision + model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model") + from src.config import get_app_config + + app_config = get_app_config() + # If no model_name specified, use the first model (default) + if model_name is None and app_config.models: + model_name = app_config.models[0].name + + model_config = app_config.get_model_config(model_name) if model_name else None + if model_config is not None and model_config.supports_vision: + middlewares.append(ViewImageMiddleware()) + + # ClarificationMiddleware should always be last + middlewares.append(ClarificationMiddleware()) return middlewares diff --git a/backend/src/agents/middlewares/view_image_middleware.py b/backend/src/agents/middlewares/view_image_middleware.py new file mode 100644 index 0000000..404cf40 --- /dev/null +++ b/backend/src/agents/middlewares/view_image_middleware.py @@ -0,0 +1,221 @@ +"""Middleware for injecting image details into conversation before LLM call.""" + +from typing import NotRequired, override + +from langchain.agents import AgentState +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langgraph.runtime import Runtime + +from src.agents.thread_state import ViewedImageData + + +class ViewImageMiddlewareState(AgentState): + """Compatible with the `ThreadState` schema.""" + + viewed_images: NotRequired[dict[str, ViewedImageData] | None] + + +class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]): + """Injects image details as a human message before LLM calls when view_image tools have completed. + + This middleware: + 1. Runs before each LLM call + 2. Checks if the last assistant message contains view_image tool calls + 3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages) + 4. If conditions are met, creates a human message with all viewed image details (including base64 data) + 5. Adds the message to state so the LLM can see and analyze the images + + This enables the LLM to automatically receive and analyze images that were loaded via view_image tool, + without requiring explicit user prompts to describe the images. + """ + + state_schema = ViewImageMiddlewareState + + def _get_last_assistant_message(self, messages: list) -> AIMessage | None: + """Get the last assistant message from the message list. + + Args: + messages: List of messages + + Returns: + Last AIMessage or None if not found + """ + for msg in reversed(messages): + if isinstance(msg, AIMessage): + return msg + return None + + def _has_view_image_tool(self, message: AIMessage) -> bool: + """Check if the assistant message contains view_image tool calls. + + Args: + message: Assistant message to check + + Returns: + True if message contains view_image tool calls + """ + if not hasattr(message, "tool_calls") or not message.tool_calls: + return False + + return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls) + + def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool: + """Check if all tool calls in the assistant message have been completed. + + Args: + messages: List of all messages + assistant_msg: The assistant message containing tool calls + + Returns: + True if all tool calls have corresponding ToolMessages + """ + if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls: + return False + + # Get all tool call IDs from the assistant message + tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")} + + # Find the index of the assistant message + try: + assistant_idx = messages.index(assistant_msg) + except ValueError: + return False + + # Get all ToolMessages after the assistant message + completed_tool_ids = set() + for msg in messages[assistant_idx + 1 :]: + if isinstance(msg, ToolMessage) and msg.tool_call_id: + completed_tool_ids.add(msg.tool_call_id) + + # Check if all tool calls have been completed + return tool_call_ids.issubset(completed_tool_ids) + + def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]: + """Create a formatted message with all viewed image details. + + Args: + state: Current state containing viewed_images + + Returns: + List of content blocks (text and images) for the HumanMessage + """ + viewed_images = state.get("viewed_images", {}) + if not viewed_images: + return ["No images have been viewed."] + + # Build the message with image information + content_blocks: list[str | dict] = [{"type": "text", "text": "Here are the images you've viewed:"}] + + for image_path, image_data in viewed_images.items(): + mime_type = image_data.get("mime_type", "unknown") + base64_data = image_data.get("base64", "") + + # Add text description + content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"}) + + # Add the actual image data so LLM can "see" it + if base64_data: + content_blocks.append( + { + "type": "image_url", + "image_url": {"url": f"data:{mime_type};base64,{base64_data}"}, + } + ) + + return content_blocks + + def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool: + """Determine if we should inject an image details message. + + Args: + state: Current state + + Returns: + True if we should inject the message + """ + messages = state.get("messages", []) + if not messages: + return False + + # Get the last assistant message + last_assistant_msg = self._get_last_assistant_message(messages) + if not last_assistant_msg: + return False + + # Check if it has view_image tool calls + if not self._has_view_image_tool(last_assistant_msg): + return False + + # Check if all tools have been completed + if not self._all_tools_completed(messages, last_assistant_msg): + return False + + # Check if we've already added an image details message + # Look for a human message after the last assistant message that contains image details + assistant_idx = messages.index(last_assistant_msg) + for msg in messages[assistant_idx + 1 :]: + if isinstance(msg, HumanMessage): + content_str = str(msg.content) + if "Here are the images you've viewed" in content_str or "Here are the details of the images you've viewed" in content_str: + # Already added, don't add again + return False + + return True + + def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None: + """Internal helper to inject image details message. + + Args: + state: Current state + + Returns: + State update with additional human message, or None if no update needed + """ + if not self._should_inject_image_message(state): + return None + + # Create the image details message with text and image content + image_content = self._create_image_details_message(state) + + # Create a new human message with mixed content (text + images) + human_msg = HumanMessage(content=image_content) + + print("[ViewImageMiddleware] Injecting image details message with images before LLM call") + + # Return state update with the new message + return {"messages": [human_msg]} + + @override + def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None: + """Inject image details message before LLM call if view_image tools have completed (sync version). + + This runs before each LLM call, checking if the previous turn included view_image + tool calls that have all completed. If so, it injects a human message with the image + details so the LLM can see and analyze the images. + + Args: + state: Current state + runtime: Runtime context (unused but required by interface) + + Returns: + State update with additional human message, or None if no update needed + """ + return self._inject_image_message(state) + + @override + async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None: + """Inject image details message before LLM call if view_image tools have completed (async version). + + This runs before each LLM call, checking if the previous turn included view_image + tool calls that have all completed. If so, it injects a human message with the image + details so the LLM can see and analyze the images. + + Args: + state: Current state + runtime: Runtime context (unused but required by interface) + + Returns: + State update with additional human message, or None if no update needed + """ + return self._inject_image_message(state) diff --git a/backend/src/agents/thread_state.py b/backend/src/agents/thread_state.py index 358adc5..2d87c3e 100644 --- a/backend/src/agents/thread_state.py +++ b/backend/src/agents/thread_state.py @@ -1,4 +1,4 @@ -from typing import NotRequired, TypedDict +from typing import Annotated, NotRequired, TypedDict from langchain.agents import AgentState @@ -13,10 +13,43 @@ class ThreadDataState(TypedDict): outputs_path: NotRequired[str | None] +class ViewedImageData(TypedDict): + base64: str + mime_type: str + + +def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]: + """Reducer for artifacts list - merges and deduplicates artifacts.""" + if existing is None: + return new or [] + if new is None: + return existing + # Use dict.fromkeys to deduplicate while preserving order + return list(dict.fromkeys(existing + new)) + + +def merge_viewed_images(existing: dict[str, ViewedImageData] | None, new: dict[str, ViewedImageData] | None) -> dict[str, ViewedImageData]: + """Reducer for viewed_images dict - merges image dictionaries. + + Special case: If new is an empty dict {}, it clears the existing images. + This allows middlewares to clear the viewed_images state after processing. + """ + if existing is None: + return new or {} + if new is None: + return existing + # Special case: empty dict means clear all viewed images + if len(new) == 0: + return {} + # Merge dictionaries, new values override existing ones for same keys + return {**existing, **new} + + class ThreadState(AgentState): sandbox: NotRequired[SandboxState | None] thread_data: NotRequired[ThreadDataState | None] title: NotRequired[str | None] - artifacts: NotRequired[list[str] | None] + artifacts: Annotated[list[str], merge_artifacts] todos: NotRequired[list | None] uploaded_files: NotRequired[list[dict] | None] + viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type} diff --git a/backend/src/community/firecrawl/tools.py b/backend/src/community/firecrawl/tools.py index 3ca22f2..0bf46a6 100644 --- a/backend/src/community/firecrawl/tools.py +++ b/backend/src/community/firecrawl/tools.py @@ -70,4 +70,4 @@ def web_fetch_tool(url: str) -> str: except Exception as e: return f"Error: {str(e)}" - return f"# {title}\n\n{markdown_content}" + return f"# {title}\n\n{markdown_content[:4096]}" diff --git a/backend/src/community/jina_ai/tools.py b/backend/src/community/jina_ai/tools.py index c87b011..1a9cb41 100644 --- a/backend/src/community/jina_ai/tools.py +++ b/backend/src/community/jina_ai/tools.py @@ -25,4 +25,4 @@ def web_fetch_tool(url: str) -> str: timeout = config.model_extra.get("timeout") html_content = jina_client.crawl(url, return_format="html", timeout=timeout) article = readability_extractor.extract_article(html_content) - return article.to_markdown() + return article.to_markdown()[:4096] diff --git a/backend/src/community/tavily/tools.py b/backend/src/community/tavily/tools.py index 6a22d87..d3741d9 100644 --- a/backend/src/community/tavily/tools.py +++ b/backend/src/community/tavily/tools.py @@ -57,6 +57,6 @@ def web_fetch_tool(url: str) -> str: return f"Error: {res['failed_results'][0]['error']}" elif "results" in res and len(res["results"]) > 0: result = res["results"][0] - return f"# {result['title']}\n\n{result['raw_content']}" + return f"# {result['title']}\n\n{result['raw_content'][:4096]}" else: return "Error: No results found" diff --git a/backend/src/config/model_config.py b/backend/src/config/model_config.py index a505c8f..277de2e 100644 --- a/backend/src/config/model_config.py +++ b/backend/src/config/model_config.py @@ -18,3 +18,4 @@ class ModelConfig(BaseModel): default_factory=lambda: None, description="Extra settings to be passed to the model when thinking is enabled", ) + supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs") diff --git a/backend/src/models/factory.py b/backend/src/models/factory.py index 8189cee..c9517a0 100644 --- a/backend/src/models/factory.py +++ b/backend/src/models/factory.py @@ -29,6 +29,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, * "description", "supports_thinking", "when_thinking_enabled", + "supports_vision", }, ) if thinking_enabled and model_config.when_thinking_enabled is not None: diff --git a/backend/src/tools/builtins/__init__.py b/backend/src/tools/builtins/__init__.py index 7d3f5ab..50bbcd9 100644 --- a/backend/src/tools/builtins/__init__.py +++ b/backend/src/tools/builtins/__init__.py @@ -1,4 +1,5 @@ from .clarification_tool import ask_clarification_tool from .present_file_tool import present_file_tool +from .view_image_tool import view_image_tool -__all__ = ["present_file_tool", "ask_clarification_tool"] +__all__ = ["present_file_tool", "ask_clarification_tool", "view_image_tool"] diff --git a/backend/src/tools/builtins/present_file_tool.py b/backend/src/tools/builtins/present_file_tool.py index a388bd8..de5c41a 100644 --- a/backend/src/tools/builtins/present_file_tool.py +++ b/backend/src/tools/builtins/present_file_tool.py @@ -28,15 +28,12 @@ def present_file_tool( Notes: - You should call this tool after creating files and moving them to the `/mnt/user-data/outputs` directory. - - IMPORTANT: Do NOT call this tool in parallel with other tools. Call it separately. + - This tool can be safely called in parallel with other tools. State updates are handled by a reducer to prevent conflicts. Args: filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented. """ - existing_artifacts = runtime.state.get("artifacts") or [] - # Use dict.fromkeys to deduplicate while preserving order - new_artifacts = list(dict.fromkeys(existing_artifacts + filepaths)) - runtime.state["artifacts"] = new_artifacts + # The merge_artifacts reducer will handle merging and deduplication return Command( - update={"artifacts": new_artifacts, "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)]}, + update={"artifacts": filepaths, "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)]}, ) diff --git a/backend/src/tools/builtins/view_image_tool.py b/backend/src/tools/builtins/view_image_tool.py new file mode 100644 index 0000000..f979294 --- /dev/null +++ b/backend/src/tools/builtins/view_image_tool.py @@ -0,0 +1,94 @@ +import base64 +import mimetypes +from pathlib import Path +from typing import Annotated + +from langchain.tools import InjectedToolCallId, ToolRuntime, tool +from langchain_core.messages import ToolMessage +from langgraph.types import Command +from langgraph.typing import ContextT + +from src.agents.thread_state import ThreadState +from src.sandbox.tools import get_thread_data, replace_virtual_path + + +@tool("view_image", parse_docstring=True) +def view_image_tool( + runtime: ToolRuntime[ContextT, ThreadState], + image_path: str, + tool_call_id: Annotated[str, InjectedToolCallId], +) -> Command: + """Read an image file. + + Use this tool to read an image file and make it available for display. + + When to use the view_image tool: + - When you need to view an image file. + + When NOT to use the view_image tool: + - For non-image files (use present_files instead) + - For multiple files at once (use present_files instead) + + Args: + image_path: Absolute path to the image file. Common formats supported: jpg, jpeg, png, webp. + """ + # Replace virtual path with actual path + # /mnt/user-data/* paths are mapped to thread-specific directories + thread_data = get_thread_data(runtime) + actual_path = replace_virtual_path(image_path, thread_data) + + # Validate that the path is absolute + path = Path(actual_path) + if not path.is_absolute(): + return Command( + update={"messages": [ToolMessage(f"Error: Path must be absolute, got: {image_path}", tool_call_id=tool_call_id)]}, + ) + + # Validate that the file exists + if not path.exists(): + return Command( + update={"messages": [ToolMessage(f"Error: Image file not found: {image_path}", tool_call_id=tool_call_id)]}, + ) + + # Validate that it's a file (not a directory) + if not path.is_file(): + return Command( + update={"messages": [ToolMessage(f"Error: Path is not a file: {image_path}", tool_call_id=tool_call_id)]}, + ) + + # Validate image extension + valid_extensions = {".jpg", ".jpeg", ".png", ".webp"} + if path.suffix.lower() not in valid_extensions: + return Command( + update={"messages": [ToolMessage(f"Error: Unsupported image format: {path.suffix}. Supported formats: {', '.join(valid_extensions)}", tool_call_id=tool_call_id)]}, + ) + + # Detect MIME type from file extension + mime_type, _ = mimetypes.guess_type(actual_path) + if mime_type is None: + # Fallback to default MIME types for common image formats + extension_to_mime = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".webp": "image/webp", + } + mime_type = extension_to_mime.get(path.suffix.lower(), "application/octet-stream") + + # Read image file and convert to base64 + try: + with open(actual_path, "rb") as f: + image_data = f.read() + image_base64 = base64.b64encode(image_data).decode("utf-8") + except Exception as e: + return Command( + update={"messages": [ToolMessage(f"Error reading image file: {str(e)}", tool_call_id=tool_call_id)]}, + ) + + # Update viewed_images in state + # The merge_viewed_images reducer will handle merging with existing images + new_viewed_images = {image_path: {"base64": image_base64, "mime_type": mime_type}} + + return Command( + update={"viewed_images": new_viewed_images, "messages": [ToolMessage("Successfully read image", tool_call_id=tool_call_id)]}, + ) diff --git a/config.example.yaml b/config.example.yaml index 7ec8e89..132d74e 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -21,6 +21,7 @@ models: api_key: $OPENAI_API_KEY # Use environment variable max_tokens: 4096 temperature: 0.7 + supports_vision: true # Enable vision support for view_image tool # Example: Anthropic Claude model # - name: claude-3-5-sonnet @@ -29,6 +30,7 @@ models: # model: claude-3-5-sonnet-20241022 # api_key: $ANTHROPIC_API_KEY # max_tokens: 8192 + # supports_vision: true # Enable vision support for view_image tool # Example: DeepSeek model (with thinking support) # - name: deepseek-v3 @@ -38,6 +40,7 @@ models: # api_key: $DEEPSEEK_API_KEY # max_tokens: 16384 # supports_thinking: true + # supports_vision: false # DeepSeek V3 does not support vision # when_thinking_enabled: # extra_body: # thinking: @@ -51,6 +54,7 @@ models: # api_base: https://ark.cn-beijing.volces.com/api/v3 # api_key: $VOLCENGINE_API_KEY # supports_thinking: true + # supports_vision: false # Check your specific model's capabilities # when_thinking_enabled: # extra_body: # thinking: @@ -65,6 +69,7 @@ models: # api_key: $MOONSHOT_API_KEY # max_tokens: 32768 # supports_thinking: true + # supports_vision: false # Check your specific model's capabilities # when_thinking_enabled: # extra_body: # thinking: @@ -107,6 +112,11 @@ tools: use: src.community.image_search.tools:image_search_tool max_results: 5 + # View image tool (display local images to user) + - name: view_image + group: file:read + use: src.tools.builtins:view_image_tool + # File operations tools - name: ls group: file:read