2025-04-17 11:34:42 +08:00
|
|
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
|
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
2025-10-16 17:38:18 +08:00
|
|
|
import asyncio
|
2025-04-18 15:28:31 +08:00
|
|
|
import base64
|
2025-04-13 21:14:31 +08:00
|
|
|
import json
|
|
|
|
|
import logging
|
2025-10-16 17:38:18 +08:00
|
|
|
import os
|
feat: implement tool-specific interrupts for create_react_agent (#572) (#659)
* feat: implement tool-specific interrupts for create_react_agent (#572)
Add selective tool interrupt capability allowing interrupts before specific tools
rather than all tools. Users can now configure which tools trigger interrupts via
the interrupt_before_tools parameter.
Changes:
- Create ToolInterceptor class to handle tool-specific interrupt logic
- Add interrupt_before_tools parameter to create_agent() function
- Extend Configuration with interrupt_before_tools field
- Add interrupt_before_tools to ChatRequest API
- Update nodes.py to pass interrupt configuration to agents
- Update app.py workflow to support tool interrupt configuration
- Add comprehensive unit tests for tool interceptor
Features:
- Selective tool interrupts: interrupt only specific tools by name
- Approval keywords: recognize user approval (approved, proceed, accept, etc.)
- Backward compatible: optional parameter, existing code unaffected
- Flexible: works with default tools and MCP-powered tools
- Works with existing resume mechanism for seamless workflow
Example usage:
request = ChatRequest(
messages=[...],
interrupt_before_tools=['db_tool', 'sensitive_api']
)
* test: add comprehensive integration tests for tool-specific interrupts (#572)
Add 24 integration tests covering all aspects of the tool interceptor feature:
Test Coverage:
- Agent creation with tool interrupts
- Configuration support (with/without interrupts)
- ChatRequest API integration
- Multiple tools with selective interrupts
- User approval/rejection flows
- Tool wrapping and functionality preservation
- Error handling and edge cases
- Approval keyword recognition
- Complex tool inputs
- Logging and monitoring
All tests pass with 100% coverage of tool interceptor functionality.
Tests verify:
✓ Selective tool interrupts work correctly
✓ Only specified tools trigger interrupts
✓ Non-matching tools execute normally
✓ User feedback is properly parsed
✓ Tool functionality is preserved after wrapping
✓ Error handling works as expected
✓ Configuration options are properly respected
✓ Logging provides useful debugging info
* fix: mock get_llm_by_type in agent creation test
Fix test_agent_creation_with_tool_interrupts which was failing because
get_llm_by_type() was being called before create_react_agent was mocked.
Changes:
- Add mock for get_llm_by_type in test
- Use context manager composition for multiple patches
- Test now passes and validates tool wrapping correctly
All 24 integration tests now pass successfully.
* refactor: use mock assertion methods for consistent and clearer error messages
Update integration tests to use mock assertion methods instead of direct
attribute checking for consistency and clearer error messages:
Changes:
- Replace 'assert mock_interrupt.called' with 'mock_interrupt.assert_called()'
- Replace 'assert not mock_interrupt.called' with 'mock_interrupt.assert_not_called()'
Benefits:
- Consistent with pytest-mock and unittest.mock best practices
- Clearer error messages when assertions fail
- Better IDE autocompletion support
- More professional test code
All 42 tests pass with improved assertion patterns.
* refactor: use default_factory for interrupt_before_tools consistency
Improve consistency between ChatRequest and Configuration implementations:
Changes:
- ChatRequest.interrupt_before_tools: Use Field(default_factory=list) instead of Optional[None]
- Remove unnecessary 'or []' conversion in app.py line 505
- Aligns with Configuration.interrupt_before_tools implementation pattern
- No functional changes - all tests still pass
Benefits:
- Consistent field definition across codebase
- Simpler and cleaner code
- Reduced chance of None/empty list bugs
- Better alignment with Pydantic best practices
All 42 tests passing.
* refactor: improve tool input formatting in interrupt messages
Enhance tool input representation for better readability in interrupt messages:
Changes:
- Add json import for better formatting
- Create _format_tool_input() static method with JSON serialization
- Use JSON formatting for dicts, lists, tuples with indent=2
- Fall back to str() for non-serializable types
- Handle None input specially (returns 'No input')
- Improve interrupt message formatting with better spacing
Benefits:
- Complex tool inputs now display as readable JSON
- Nested structures are properly indented and visible
- Better user experience when reviewing tool inputs before approval
- Handles edge cases gracefully with fallbacks
- Improved logging output for debugging
Example improvements:
Before: {'query': 'SELECT...', 'limit': 10, 'nested': {'key': 'value'}}
After:
{
"query": "SELECT...",
"limit": 10,
"nested": {
"key": "value"
}
}
All 42 tests still passing.
* test: add comprehensive unit tests for tool input formatting
2025-10-26 09:47:03 +08:00
|
|
|
from typing import Annotated, Any, List, Optional, cast
|
2025-04-13 21:14:31 +08:00
|
|
|
from uuid import uuid4
|
|
|
|
|
|
2025-05-28 14:13:46 +08:00
|
|
|
from fastapi import FastAPI, HTTPException, Query
|
2025-04-13 21:14:31 +08:00
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
2025-04-19 17:37:40 +08:00
|
|
|
from fastapi.responses import Response, StreamingResponse
|
2025-07-04 08:27:20 +08:00
|
|
|
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
2025-08-16 21:03:12 +08:00
|
|
|
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
|
|
|
|
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
2025-09-16 20:30:45 +08:00
|
|
|
from langgraph.store.memory import InMemoryStore
|
|
|
|
|
from langgraph.types import Command
|
2025-08-16 21:03:12 +08:00
|
|
|
from psycopg_pool import AsyncConnectionPool
|
2025-04-13 21:14:31 +08:00
|
|
|
|
2025-09-12 22:20:55 +08:00
|
|
|
from src.config.configuration import get_recursion_limit
|
|
|
|
|
from src.config.loader import get_bool_env, get_str_env
|
2025-06-07 20:48:39 +08:00
|
|
|
from src.config.report_style import ReportStyle
|
2025-05-28 14:13:46 +08:00
|
|
|
from src.config.tools import SELECTED_RAG_PROVIDER
|
2025-04-22 15:33:53 +08:00
|
|
|
from src.graph.builder import build_graph_with_memory
|
2025-09-16 20:30:45 +08:00
|
|
|
from src.graph.checkpoint import chat_stream_message
|
2025-10-22 22:49:07 +08:00
|
|
|
from src.graph.utils import (
|
|
|
|
|
build_clarified_topic_from_history,
|
|
|
|
|
reconstruct_clarification_history,
|
|
|
|
|
)
|
2025-07-04 08:27:20 +08:00
|
|
|
from src.llms.llm import get_configured_llm_models
|
2025-04-19 17:37:40 +08:00
|
|
|
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
2025-04-21 16:43:06 +08:00
|
|
|
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
2025-06-08 19:41:59 +08:00
|
|
|
from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph
|
2025-07-04 08:27:20 +08:00
|
|
|
from src.prose.graph.builder import build_graph as build_prose_graph
|
2025-05-28 14:13:46 +08:00
|
|
|
from src.rag.builder import build_retriever
|
2025-09-12 22:20:55 +08:00
|
|
|
from src.rag.milvus import load_examples
|
2025-05-28 14:13:46 +08:00
|
|
|
from src.rag.retriever import Resource
|
2025-04-19 22:11:41 +08:00
|
|
|
from src.server.chat_request import (
|
|
|
|
|
ChatRequest,
|
2025-06-08 19:41:59 +08:00
|
|
|
EnhancePromptRequest,
|
2025-04-19 22:11:41 +08:00
|
|
|
GeneratePodcastRequest,
|
2025-04-21 16:43:06 +08:00
|
|
|
GeneratePPTRequest,
|
2025-04-26 23:12:13 +08:00
|
|
|
GenerateProseRequest,
|
2025-04-19 22:11:41 +08:00
|
|
|
TTSRequest,
|
|
|
|
|
)
|
2025-07-04 08:27:20 +08:00
|
|
|
from src.server.config_request import ConfigResponse
|
2025-04-23 14:38:04 +08:00
|
|
|
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
|
|
|
|
|
from src.server.mcp_utils import load_mcp_tools
|
2025-05-28 14:13:46 +08:00
|
|
|
from src.server.rag_request import (
|
|
|
|
|
RAGConfigResponse,
|
|
|
|
|
RAGResourceRequest,
|
|
|
|
|
RAGResourcesResponse,
|
|
|
|
|
)
|
2025-04-18 15:28:31 +08:00
|
|
|
from src.tools import VolcengineTTS
|
2025-08-16 21:03:12 +08:00
|
|
|
from src.utils.json_utils import sanitize_args
|
2025-10-27 20:57:23 +08:00
|
|
|
from src.utils.log_sanitizer import (
|
|
|
|
|
sanitize_agent_name,
|
|
|
|
|
sanitize_log_input,
|
|
|
|
|
sanitize_thread_id,
|
|
|
|
|
sanitize_tool_name,
|
|
|
|
|
sanitize_user_content,
|
|
|
|
|
)
|
2025-04-13 21:14:31 +08:00
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Configure Windows event loop policy for PostgreSQL compatibility
# On Windows, psycopg requires a selector-based event loop, not the default ProactorEventLoop
if os.name == "nt":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Generic detail string returned to clients on unexpected failures, so
# internal error messages are never leaked in HTTP responses.
INTERNAL_SERVER_ERROR_DETAIL = "Internal Server Error"

app = FastAPI(
    title="DeerFlow API",
    description="API for Deer",
    version="0.1.0",
)

# Add CORS middleware
# It's recommended to load the allowed origins from an environment variable
# for better security and flexibility across different environments.
allowed_origins_str = get_str_env("ALLOWED_ORIGINS", "http://localhost:3000")
allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",")]

logger.info(f"Allowed origins: {allowed_origins}")

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,  # Restrict to specific origins
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],  # Use the configured list of methods
    allow_headers=["*"],  # Now allow all headers, but can be restricted further
)

# Load examples into Milvus if configured
# NOTE(review): this runs at import time and may perform I/O — confirm that is
# acceptable for all deployment entry points (e.g. test collection).
load_examples()

# Shared in-process store and the memory-backed workflow graph used by the
# streaming chat endpoint below.
in_memory_store = InMemoryStore()
graph = build_graph_with_memory()
|
2025-04-13 21:14:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream a chat workflow run to the client as server-sent events.

    Validates MCP usage against the ENABLE_MCP_SERVER_CONFIGURATION flag,
    assigns a fresh thread id when the client sent the "__default__"
    placeholder, and returns a StreamingResponse driven by
    _astream_workflow_generator.

    Raises:
        HTTPException: 403 when MCP settings are supplied but MCP server
            configuration is disabled.
    """
    # Check if MCP server configuration is enabled
    mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)

    # Validate MCP settings if provided
    if request.mcp_settings and not mcp_enabled:
        raise HTTPException(
            status_code=403,
            detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",
        )

    # "__default__" is a client-side sentinel meaning "server, pick an id".
    thread_id = request.thread_id
    if thread_id == "__default__":
        thread_id = str(uuid4())

    # NOTE: the arguments below are positional — their order must match the
    # signature of _astream_workflow_generator (defined elsewhere in this file).
    return StreamingResponse(
        _astream_workflow_generator(
            request.model_dump()["messages"],
            thread_id,
            request.resources,
            request.max_plan_iterations,
            request.max_step_num,
            request.max_search_results,
            request.auto_accepted_plan,
            request.interrupt_feedback,
            # MCP settings are forwarded only when the feature flag is on.
            request.mcp_settings if mcp_enabled else {},
            request.enable_background_investigation,
            request.report_style,
            request.enable_deep_thinking,
            request.enable_clarification,
            request.max_clarification_rounds,
            request.locale,
            request.interrupt_before_tools,
        ),
        media_type="text/event-stream",
    )
|
|
|
|
|
|
|
|
|
|
|
2025-10-24 22:26:25 +08:00
|
|
|
def _validate_tool_call_chunks(tool_call_chunks):
|
|
|
|
|
"""Validate and log tool call chunk structure for debugging."""
|
|
|
|
|
if not tool_call_chunks:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
logger.debug(f"Validating tool_call_chunks: count={len(tool_call_chunks)}")
|
|
|
|
|
|
|
|
|
|
indices_seen = set()
|
|
|
|
|
tool_ids_seen = set()
|
|
|
|
|
|
|
|
|
|
for i, chunk in enumerate(tool_call_chunks):
|
|
|
|
|
index = chunk.get("index")
|
|
|
|
|
tool_id = chunk.get("id")
|
|
|
|
|
name = chunk.get("name", "")
|
|
|
|
|
has_args = "args" in chunk
|
|
|
|
|
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"Chunk {i}: index={index}, id={tool_id}, name={name}, "
|
|
|
|
|
f"has_args={has_args}, type={chunk.get('type')}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if index is not None:
|
|
|
|
|
indices_seen.add(index)
|
|
|
|
|
if tool_id:
|
|
|
|
|
tool_ids_seen.add(tool_id)
|
|
|
|
|
|
|
|
|
|
if len(indices_seen) > 1:
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"Multiple indices detected: {sorted(indices_seen)} - "
|
|
|
|
|
f"This may indicate consecutive tool calls"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
2025-08-16 21:03:12 +08:00
|
|
|
def _process_tool_call_chunks(tool_call_chunks):
    """
    Process tool call chunks with proper index-based grouping.

    This function handles the concatenation of tool call chunks that belong
    to the same tool call (same index) while properly segregating chunks
    from different tool calls (different indices).

    The issue: In streaming, LangChain's ToolCallChunk concatenates string
    attributes (name, args) when chunks have the same index. We need to:
    1. Group chunks by index
    2. Detect index collisions with different tool names
    3. Accumulate arguments for the same index
    4. Return properly segregated tool calls
    """
    if not tool_call_chunks:
        return []

    # Diagnostic-only pass: logs chunk structure, does not modify anything.
    _validate_tool_call_chunks(tool_call_chunks)

    chunks = []
    chunk_by_index = {}  # Group chunks by index to handle streaming accumulation

    for chunk in tool_call_chunks:
        index = chunk.get("index")
        chunk_id = chunk.get("id")

        if index is not None:
            # Create or update entry for this index
            if index not in chunk_by_index:
                chunk_by_index[index] = {
                    "name": "",
                    "args": "",
                    "id": chunk_id or "",
                    "index": index,
                    "type": chunk.get("type", ""),
                }

            # Validate and accumulate tool name
            chunk_name = chunk.get("name", "")
            if chunk_name:
                stored_name = chunk_by_index[index]["name"]

                # Check for index collision with different tool names
                if stored_name and stored_name != chunk_name:
                    logger.warning(
                        f"Tool name mismatch detected at index {index}: "
                        f"'{stored_name}' != '{chunk_name}'. "
                        f"This may indicate a streaming artifact or consecutive tool calls "
                        f"with the same index assignment."
                    )
                    # Keep the first name to prevent concatenation
                else:
                    chunk_by_index[index]["name"] = chunk_name

            # Update ID if new one provided
            # (only the first non-empty id wins; later ids are ignored)
            if chunk_id and not chunk_by_index[index]["id"]:
                chunk_by_index[index]["id"] = chunk_id

            # Accumulate arguments
            # (args fragments for the same index are concatenated in arrival order)
            if chunk.get("args"):
                chunk_by_index[index]["args"] += chunk.get("args", "")
        else:
            # Handle chunks without explicit index (edge case)
            # NOTE(review): unindexed chunks are appended here, BEFORE the
            # indexed ones below, and are forced to index 0 — confirm this
            # ordering (and the possible index-0 collision) is intended.
            logger.debug(f"Chunk without index encountered: {chunk}")
            chunks.append({
                "name": chunk.get("name", ""),
                "args": sanitize_args(chunk.get("args", "")),
                "id": chunk.get("id", ""),
                "index": 0,
                "type": chunk.get("type", ""),
            })

    # Convert indexed chunks to list, sorted by index for proper order
    for index in sorted(chunk_by_index.keys()):
        chunk_data = chunk_by_index[index]
        # Sanitize the fully accumulated args string exactly once, at the end.
        chunk_data["args"] = sanitize_args(chunk_data["args"])
        chunks.append(chunk_data)

        logger.debug(
            f"Processed tool call: index={index}, name={chunk_data['name']}, "
            f"id={chunk_data['id']}"
        )

    return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_agent_name(agent, message_metadata):
|
|
|
|
|
"""Extract agent name from agent tuple."""
|
|
|
|
|
agent_name = "unknown"
|
|
|
|
|
if agent and len(agent) > 0:
|
|
|
|
|
agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0]
|
|
|
|
|
else:
|
|
|
|
|
agent_name = message_metadata.get("langgraph_node", "unknown")
|
|
|
|
|
return agent_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_event_stream_message(
|
|
|
|
|
message_chunk, message_metadata, thread_id, agent_name
|
|
|
|
|
):
|
|
|
|
|
"""Create base event stream message."""
|
2025-10-20 23:10:58 +08:00
|
|
|
content = message_chunk.content
|
|
|
|
|
if not isinstance(content, str):
|
|
|
|
|
content = json.dumps(content, ensure_ascii=False)
|
2025-10-22 22:49:07 +08:00
|
|
|
|
2025-08-16 21:03:12 +08:00
|
|
|
event_stream_message = {
|
|
|
|
|
"thread_id": thread_id,
|
|
|
|
|
"agent": agent_name,
|
|
|
|
|
"id": message_chunk.id,
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"checkpoint_ns": message_metadata.get("checkpoint_ns", ""),
|
|
|
|
|
"langgraph_node": message_metadata.get("langgraph_node", ""),
|
|
|
|
|
"langgraph_path": message_metadata.get("langgraph_path", ""),
|
|
|
|
|
"langgraph_step": message_metadata.get("langgraph_step", ""),
|
2025-10-20 23:10:58 +08:00
|
|
|
"content": content,
|
2025-08-16 21:03:12 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Add optional fields
|
|
|
|
|
if message_chunk.additional_kwargs.get("reasoning_content"):
|
|
|
|
|
event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
|
|
|
|
|
"reasoning_content"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
if message_chunk.response_metadata.get("finish_reason"):
|
|
|
|
|
event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
|
|
|
|
|
"finish_reason"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return event_stream_message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_interrupt_event(thread_id, event_data):
    """Create interrupt event.

    Wraps the first entry of event_data["__interrupt__"] into an
    "interrupt" SSE event via _make_event, offering the client the
    edit-plan / accept options.
    """
    # NOTE(review): assumes event_data["__interrupt__"] is a non-empty
    # sequence of objects exposing .ns (sequence) and .value — matches the
    # shape the caller checks for, but confirm against langgraph's Interrupt.
    return _make_event(
        "interrupt",
        {
            "thread_id": thread_id,
            # The interrupt's namespace head doubles as the event id.
            "id": event_data["__interrupt__"][0].ns[0],
            "role": "assistant",
            "content": event_data["__interrupt__"][0].value,
            "finish_reason": "interrupt",
            # Choices surfaced to the user for resuming the workflow.
            "options": [
                {"text": "Edit plan", "value": "edit_plan"},
                {"text": "Start research", "value": "accepted"},
            ],
        },
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _process_initial_messages(message, thread_id):
    """Persist an initial user message as a formatted SSE message_chunk.

    Serializes the message into the event-stream JSON shape and hands it to
    chat_stream_message for storage/delivery. Returns nothing (despite the
    streaming context, this function does not yield).
    """
    json_data = json.dumps(
        {
            "thread_id": thread_id,
            # Reuse the client-supplied id when present; otherwise mint one.
            "id": "run--" + message.get("id", uuid4().hex),
            "role": "user",
            "content": message.get("content", ""),
        },
        ensure_ascii=False,
        separators=(",", ":"),
    )
    # "none" is the finish-reason marker used for stored chunks here —
    # presumably meaning "not a terminal chunk"; verify against
    # chat_stream_message's contract.
    chat_stream_message(
        thread_id, f"event: message_chunk\ndata: {json_data}\n\n", "none"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
    """Process a single message chunk and yield appropriate events.

    Dispatches on the chunk type:
      * ToolMessage            -> one "tool_call_result" event
      * AIMessageChunk with complete tool_calls  -> one "tool_calls" event
      * AIMessageChunk with streaming tool_call_chunks -> one
        "tool_call_chunks" event
      * any other AIMessageChunk -> one "message_chunk" event (raw tokens)

    All logged identifiers pass through the sanitize_* helpers to avoid
    log injection from model/user-controlled values.
    """
    agent_name = _get_agent_name(agent, message_metadata)
    safe_agent_name = sanitize_agent_name(agent_name)
    safe_thread_id = sanitize_thread_id(thread_id)
    # NOTE(review): safe_agent is computed but never used below — candidate
    # for removal.
    safe_agent = sanitize_agent_name(agent)
    logger.debug(f"[{safe_thread_id}] _process_message_chunk started for agent={safe_agent_name}")
    logger.debug(f"[{safe_thread_id}] Extracted agent_name: {safe_agent_name}")

    # Base payload shared by every event kind emitted below.
    event_stream_message = _create_event_stream_message(
        message_chunk, message_metadata, thread_id, agent_name
    )

    if isinstance(message_chunk, ToolMessage):
        # Tool Message - Return the result of the tool call
        logger.debug(f"[{safe_thread_id}] Processing ToolMessage")
        tool_call_id = message_chunk.tool_call_id
        event_stream_message["tool_call_id"] = tool_call_id

        # Validate tool_call_id for debugging
        if tool_call_id:
            safe_tool_id = sanitize_log_input(tool_call_id, max_length=100)
            logger.debug(f"[{safe_thread_id}] ToolMessage with tool_call_id: {safe_tool_id}")
        else:
            logger.warning(f"[{safe_thread_id}] ToolMessage received without tool_call_id")

        logger.debug(f"[{safe_thread_id}] Yielding tool_call_result event")
        yield _make_event("tool_call_result", event_stream_message)

    elif isinstance(message_chunk, AIMessageChunk):
        # AI Message - Raw message tokens
        has_tool_calls = bool(message_chunk.tool_calls)
        has_chunks = bool(message_chunk.tool_call_chunks)
        logger.debug(f"[{safe_thread_id}] Processing AIMessageChunk, tool_calls={has_tool_calls}, tool_call_chunks={has_chunks}")

        if message_chunk.tool_calls:
            # AI Message - Tool Call (complete tool calls)
            safe_tool_names = [sanitize_tool_name(tc.get('name', 'unknown')) for tc in message_chunk.tool_calls]
            logger.debug(f"[{safe_thread_id}] AIMessageChunk has complete tool_calls: {safe_tool_names}")
            event_stream_message["tool_calls"] = message_chunk.tool_calls

            # Process tool_call_chunks with proper index-based grouping
            processed_chunks = _process_tool_call_chunks(
                message_chunk.tool_call_chunks
            )
            if processed_chunks:
                event_stream_message["tool_call_chunks"] = processed_chunks
                safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
                logger.debug(
                    f"[{safe_thread_id}] Tool calls: {safe_tool_names}, "
                    f"Processed chunks: {len(processed_chunks)}"
                )

            logger.debug(f"[{safe_thread_id}] Yielding tool_calls event")
            yield _make_event("tool_calls", event_stream_message)

        elif message_chunk.tool_call_chunks:
            # AI Message - Tool Call Chunks (streaming)
            chunks_count = len(message_chunk.tool_call_chunks)
            logger.debug(f"[{safe_thread_id}] AIMessageChunk has streaming tool_call_chunks: {chunks_count} chunks")
            processed_chunks = _process_tool_call_chunks(
                message_chunk.tool_call_chunks
            )

            # Emit separate events for chunks with different indices (tool call boundaries)
            if processed_chunks:
                prev_chunk = None
                for chunk in processed_chunks:
                    current_index = chunk.get("index")

                    # Log index transitions to detect tool call boundaries
                    if prev_chunk is not None and current_index != prev_chunk.get("index"):
                        prev_name = sanitize_tool_name(prev_chunk.get('name'))
                        curr_name = sanitize_tool_name(chunk.get('name'))
                        logger.debug(
                            f"[{safe_thread_id}] Tool call boundary detected: "
                            f"index {prev_chunk.get('index')} ({prev_name}) -> "
                            f"{current_index} ({curr_name})"
                        )

                    prev_chunk = chunk

                # Include all processed chunks in the event
                event_stream_message["tool_call_chunks"] = processed_chunks
                safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
                logger.debug(
                    f"[{safe_thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): "
                    f"{safe_chunk_names}"
                )

            logger.debug(f"[{safe_thread_id}] Yielding tool_call_chunks event")
            yield _make_event("tool_call_chunks", event_stream_message)

        else:
            # AI Message - Raw message tokens
            content_len = len(message_chunk.content) if isinstance(message_chunk.content, str) else 0
            logger.debug(f"[{safe_thread_id}] AIMessageChunk is raw message tokens, content_len={content_len}")
            yield _make_event("message_chunk", event_stream_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _stream_graph_events(
    graph_instance, workflow_input, workflow_config, thread_id
):
    """Stream events from the graph and process them.

    Iterates the graph's combined "messages"/"updates" stream (subgraphs
    included), forwarding interrupt events and per-token message chunks to the
    caller as server-sent-event strings.

    Args:
        graph_instance: Compiled LangGraph graph exposing ``astream``.
        workflow_input: Initial state dict, or a resume ``Command``.
        workflow_config: Config dict passed through to ``astream``.
        thread_id: Conversation thread identifier (used for logging/events).

    Yields:
        SSE event strings from ``_create_interrupt_event`` /
        ``_process_message_chunk``; on failure, a single "error" event.
    """
    safe_thread_id = sanitize_thread_id(thread_id)
    logger.debug(f"[{safe_thread_id}] Starting graph event stream with agent nodes")
    try:
        event_count = 0
        async for agent, _, event_data in graph_instance.astream(
            workflow_input,
            config=workflow_config,
            stream_mode=["messages", "updates"],
            subgraphs=True,
        ):
            event_count += 1
            safe_agent = sanitize_agent_name(agent)
            logger.debug(f"[{safe_thread_id}] Graph event #{event_count} received from agent: {safe_agent}")

            # "updates" events arrive as dicts; only interrupts are forwarded,
            # everything else in dict form is skipped.
            if isinstance(event_data, dict):
                if "__interrupt__" in event_data:
                    logger.debug(
                        f"[{safe_thread_id}] Processing interrupt event: "
                        f"ns={getattr(event_data['__interrupt__'][0], 'ns', 'unknown') if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 else 'unknown'}, "
                        f"value_len={len(getattr(event_data['__interrupt__'][0], 'value', '')) if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 and hasattr(event_data['__interrupt__'][0], 'value') and hasattr(event_data['__interrupt__'][0].value, '__len__') else 'unknown'}"
                    )
                    yield _create_interrupt_event(thread_id, event_data)
                else:
                    # BUGFIX: this log line previously ran unconditionally, so
                    # interrupt events were also logged as "without interrupt".
                    logger.debug(f"[{safe_thread_id}] Dict event without interrupt, skipping")
                continue

            # "messages" events are (chunk, metadata) tuples.
            message_chunk, message_metadata = cast(
                tuple[BaseMessage, dict[str, Any]], event_data
            )

            safe_node = sanitize_agent_name(message_metadata.get('langgraph_node', 'unknown'))
            safe_step = sanitize_log_input(message_metadata.get('langgraph_step', 'unknown'))
            logger.debug(
                f"[{safe_thread_id}] Processing message chunk: "
                f"type={type(message_chunk).__name__}, "
                f"node={safe_node}, "
                f"step={safe_step}"
            )

            async for event in _process_message_chunk(
                message_chunk, message_metadata, thread_id, agent
            ):
                yield event

        logger.debug(f"[{safe_thread_id}] Graph event stream completed. Total events: {event_count}")
    except Exception:
        logger.exception(f"[{safe_thread_id}] Error during graph execution")
        # Surface a generic error event; details stay in server logs only.
        yield _make_event(
            "error",
            {
                "thread_id": thread_id,
                "error": "Error during graph execution",
            },
        )
2025-04-13 21:14:31 +08:00
|
|
|
async def _astream_workflow_generator(
    messages: List[dict],
    thread_id: str,
    resources: List[Resource],
    max_plan_iterations: int,
    max_step_num: int,
    max_search_results: int,
    auto_accepted_plan: bool,
    interrupt_feedback: str,
    mcp_settings: dict,
    enable_background_investigation: bool,
    report_style: ReportStyle,
    enable_deep_thinking: bool,
    enable_clarification: bool,
    max_clarification_rounds: int,
    locale: str = "en-US",
    interrupt_before_tools: Optional[List[str]] = None,
):
    """Drive one chat workflow run and yield SSE event strings.

    Builds the graph input/config from the request parameters, optionally
    wraps interrupt feedback into a resume ``Command``, attaches a Postgres or
    MongoDB checkpointer when configured via environment variables, and
    delegates event production to ``_stream_graph_events``.

    Args:
        messages: Conversation history as plain dicts (``content`` keys read).
        thread_id: Conversation thread identifier.
        resources: RAG resources forwarded into the workflow config.
        max_plan_iterations: Planner iteration cap (config passthrough).
        max_step_num: Per-plan step cap (config passthrough).
        max_search_results: Search result cap (config passthrough).
        auto_accepted_plan: When False, plans wait for user feedback.
        interrupt_feedback: User reply to a pending interrupt, if any.
        mcp_settings: MCP server settings (config passthrough).
        enable_background_investigation: Workflow-state flag.
        report_style: Report style enum; ``.value`` goes into the config.
        enable_deep_thinking: Config passthrough.
        enable_clarification: Workflow-state flag.
        max_clarification_rounds: Workflow-state cap.
        locale: UI locale stored in workflow state.
        interrupt_before_tools: Tool names that trigger interrupts before
            execution (see tool-interceptor feature); None means no
            tool-specific interrupts.

    Yields:
        SSE-formatted event strings for the HTTP streaming response.
    """
    # All user-controlled values are sanitized before appearing in log lines.
    safe_thread_id = sanitize_thread_id(thread_id)
    safe_feedback = sanitize_log_input(interrupt_feedback) if interrupt_feedback else ""
    logger.debug(
        f"[{safe_thread_id}] _astream_workflow_generator starting: "
        f"messages_count={len(messages)}, "
        f"auto_accepted_plan={auto_accepted_plan}, "
        f"interrupt_feedback={safe_feedback}, "
        f"interrupt_before_tools={interrupt_before_tools}"
    )

    # Process initial messages
    logger.debug(f"[{safe_thread_id}] Processing {len(messages)} initial messages")
    for message in messages:
        if isinstance(message, dict) and "content" in message:
            safe_content = sanitize_user_content(message.get('content', ''))
            logger.debug(f"[{safe_thread_id}] Sending initial message to client: {safe_content}")
            _process_initial_messages(message, thread_id)

    # Rebuild the clarification Q/A exchange from the raw message list, then
    # derive the effective research topic from it.
    logger.debug(f"[{safe_thread_id}] Reconstructing clarification history")
    clarification_history = reconstruct_clarification_history(messages)

    logger.debug(f"[{safe_thread_id}] Building clarified topic from history")
    clarified_topic, clarification_history = build_clarified_topic_from_history(
        clarification_history
    )

    latest_message_content = messages[-1]["content"] if messages else ""
    # Fall back to the latest user message when no clarified topic exists.
    clarified_research_topic = clarified_topic or latest_message_content
    safe_topic = sanitize_user_content(clarified_research_topic)
    logger.debug(f"[{safe_thread_id}] Clarified research topic: {safe_topic}")

    # Prepare workflow input
    logger.debug(f"[{safe_thread_id}] Preparing workflow input")
    workflow_input = {
        "messages": messages,
        "plan_iterations": 0,
        "final_report": "",
        "current_plan": None,
        "observations": [],
        "auto_accepted_plan": auto_accepted_plan,
        "enable_background_investigation": enable_background_investigation,
        "research_topic": latest_message_content,
        "clarification_history": clarification_history,
        "clarified_research_topic": clarified_research_topic,
        "enable_clarification": enable_clarification,
        "max_clarification_rounds": max_clarification_rounds,
        "locale": locale,
    }

    # When the user is answering an interrupt (plan review), resume the
    # checkpointed run instead of starting a fresh one. The feedback is
    # bracketed so downstream parsing can separate it from the free text.
    if not auto_accepted_plan and interrupt_feedback:
        logger.debug(f"[{safe_thread_id}] Creating resume command with interrupt_feedback: {safe_feedback}")
        resume_msg = f"[{interrupt_feedback}]"
        if messages:
            resume_msg += f" {messages[-1]['content']}"
        workflow_input = Command(resume=resume_msg)

    # Prepare workflow config
    logger.debug(
        f"[{safe_thread_id}] Preparing workflow config: "
        f"max_plan_iterations={max_plan_iterations}, "
        f"max_step_num={max_step_num}, "
        f"report_style={report_style.value}, "
        f"enable_deep_thinking={enable_deep_thinking}"
    )
    workflow_config = {
        "thread_id": thread_id,
        "resources": resources,
        "max_plan_iterations": max_plan_iterations,
        "max_step_num": max_step_num,
        "max_search_results": max_search_results,
        "mcp_settings": mcp_settings,
        "report_style": report_style.value,
        "enable_deep_thinking": enable_deep_thinking,
        "interrupt_before_tools": interrupt_before_tools,
        "recursion_limit": get_recursion_limit(),
    }

    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
    checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")

    logger.debug(
        f"[{safe_thread_id}] Checkpoint configuration: "
        f"saver_enabled={checkpoint_saver}, "
        f"url_configured={bool(checkpoint_url)}"
    )

    # Handle checkpointer if configured
    # psycopg connection options for the Postgres checkpointer pool.
    connection_kwargs = {
        "autocommit": True,
        "row_factory": "dict_row",
        "prepare_threshold": 0,
    }
    if checkpoint_saver and checkpoint_url != "":
        # NOTE(review): a configured URL with a scheme other than
        # postgresql:// or mongodb:// streams nothing — confirm intended.
        if checkpoint_url.startswith("postgresql://"):
            logger.info(f"[{safe_thread_id}] Starting async postgres checkpointer")
            logger.debug(f"[{safe_thread_id}] Setting up PostgreSQL connection pool")
            async with AsyncConnectionPool(
                checkpoint_url, kwargs=connection_kwargs
            ) as conn:
                logger.debug(f"[{safe_thread_id}] Initializing AsyncPostgresSaver")
                checkpointer = AsyncPostgresSaver(conn)
                await checkpointer.setup()
                logger.debug(f"[{safe_thread_id}] Attaching checkpointer to graph")
                graph.checkpointer = checkpointer
                graph.store = in_memory_store
                logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
                async for event in _stream_graph_events(
                    graph, workflow_input, workflow_config, thread_id
                ):
                    yield event
                logger.debug(f"[{safe_thread_id}] Graph event streaming completed")

        if checkpoint_url.startswith("mongodb://"):
            logger.info(f"[{safe_thread_id}] Starting async mongodb checkpointer")
            logger.debug(f"[{safe_thread_id}] Setting up MongoDB connection")
            async with AsyncMongoDBSaver.from_conn_string(
                checkpoint_url
            ) as checkpointer:
                logger.debug(f"[{safe_thread_id}] Attaching MongoDB checkpointer to graph")
                graph.checkpointer = checkpointer
                graph.store = in_memory_store
                logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
                async for event in _stream_graph_events(
                    graph, workflow_input, workflow_config, thread_id
                ):
                    yield event
                logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
    else:
        logger.debug(f"[{safe_thread_id}] No checkpointer configured, using in-memory graph")
        # Use graph without MongoDB checkpointer
        logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
        async for event in _stream_graph_events(
            graph, workflow_input, workflow_config, thread_id
        ):
            yield event
        logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
|
|
|
|
def _make_event(event_type: str, data: dict[str, Any]):
    """Serialize an event payload into an SSE-formatted string.

    Empty ``content`` values are stripped before serialization. The formatted
    event is also forwarded to ``chat_stream_message`` for persistence.

    Args:
        event_type: SSE event name (e.g. "message_chunk", "error").
        data: JSON-serializable event payload; may contain "thread_id",
            "content", and "finish_reason" keys.

    Returns:
        A "event: ...\\ndata: ...\\n\\n" SSE string; on serialization failure,
        a generic "error" event instead.
    """
    # Fixed annotation: was `dict[str, any]` (builtin function), now
    # `dict[str, Any]` (typing.Any, already imported for cast() above).
    if data.get("content") == "":
        data.pop("content")
    # Ensure JSON serialization with proper encoding
    try:
        json_data = json.dumps(data, ensure_ascii=False)

        finish_reason = data.get("finish_reason", "")
        chat_stream_message(
            data.get("thread_id", ""),
            f"event: {event_type}\ndata: {json_data}\n\n",
            finish_reason,
        )

        return f"event: {event_type}\ndata: {json_data}\n\n"
    except (TypeError, ValueError) as e:
        logger.error(f"Error serializing event data: {e}")
        # Return a safe error event
        error_data = json.dumps({"error": "Serialization failed"}, ensure_ascii=False)
        return f"event: error\ndata: {error_data}\n\n"
2025-04-18 15:28:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/tts")
async def text_to_speech(request: TTSRequest):
    """Convert text to speech using volcengine TTS API.

    Requires VOLCENGINE_TTS_APPID and VOLCENGINE_TTS_ACCESS_TOKEN to be set;
    cluster and voice type fall back to documented defaults. Returns the
    decoded audio as a downloadable attachment.

    Raises:
        HTTPException: 400 when credentials are missing; 500 when the TTS
            call fails or an unexpected error occurs.
    """
    app_id = get_str_env("VOLCENGINE_TTS_APPID", "")
    if not app_id:
        raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set")
    access_token = get_str_env("VOLCENGINE_TTS_ACCESS_TOKEN", "")
    if not access_token:
        raise HTTPException(
            status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
        )

    try:
        cluster = get_str_env("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
        voice_type = get_str_env("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")

        tts_client = VolcengineTTS(
            appid=app_id,
            access_token=access_token,
            cluster=cluster,
            voice_type=voice_type,
        )
        # Call the TTS API (input capped at 1024 characters).
        result = tts_client.text_to_speech(
            text=request.text[:1024],
            encoding=request.encoding,
            speed_ratio=request.speed_ratio,
            volume_ratio=request.volume_ratio,
            pitch_ratio=request.pitch_ratio,
            text_type=request.text_type,
            with_frontend=request.with_frontend,
            frontend_type=request.frontend_type,
        )

        if not result["success"]:
            raise HTTPException(status_code=500, detail=str(result["error"]))

        # Decode the base64 audio data
        audio_data = base64.b64decode(result["audio_data"])

        # Return the audio file
        return Response(
            content=audio_data,
            media_type=f"audio/{request.encoding}",
            headers={
                "Content-Disposition": (
                    f"attachment; filename=tts_output.{request.encoding}"
                )
            },
        )

    except HTTPException:
        # BUGFIX: the deliberate 500 raised on TTS failure above was being
        # caught by the generic handler below and its detail replaced with a
        # generic message. Re-raise HTTP errors untouched.
        raise
    except Exception as e:
        logger.exception(f"Error in TTS endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-04-19 17:37:40 +08:00
|
|
|
|
|
|
|
|
|
2025-04-19 22:11:41 +08:00
|
|
|
@app.post("/api/podcast/generate")
async def generate_podcast(request: GeneratePodcastRequest):
    """Generate podcast audio from report content via the podcast graph.

    Raises:
        HTTPException: 500 on any failure during generation.
    """
    try:
        report_content = request.content
        # BUGFIX: replaced a stray `print(report_content)` with a logger call
        # so output respects the logging configuration instead of stdout.
        logger.debug("Generating podcast for content of length %d", len(report_content))
        workflow = build_podcast_graph()
        final_state = workflow.invoke({"input": report_content})
        audio_bytes = final_state["output"]
        # NOTE(review): "audio/mp3" is non-standard ("audio/mpeg" is the
        # registered type) — kept as-is since clients may depend on it.
        return Response(content=audio_bytes, media_type="audio/mp3")
    except Exception as e:
        logger.exception(f"Error occurred during podcast generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-04-21 16:43:06 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/ppt/generate")
async def generate_ppt(request: GeneratePPTRequest):
    """Generate a PPTX file from report content via the PPT graph.

    Raises:
        HTTPException: 500 on any failure during generation.
    """
    try:
        report_content = request.content
        # BUGFIX: replaced a stray `print(report_content)` with a logger call
        # so output respects the logging configuration instead of stdout.
        logger.debug("Generating PPT for content of length %d", len(report_content))
        workflow = build_ppt_graph()
        final_state = workflow.invoke({"input": report_content})
        generated_file_path = final_state["generated_file_path"]
        with open(generated_file_path, "rb") as f:
            ppt_bytes = f.read()
        # NOTE(review): the generated file is left on disk after reading —
        # confirm whether cleanup is handled elsewhere.
        return Response(
            content=ppt_bytes,
            media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
        )
    except Exception as e:
        logger.exception(f"Error occurred during ppt generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-04-23 14:38:04 +08:00
|
|
|
|
|
|
|
|
|
2025-04-26 23:12:13 +08:00
|
|
|
@app.post("/api/prose/generate")
async def generate_prose(request: GenerateProseRequest):
    """Stream prose generated by the prose graph as server-sent events.

    Returns:
        StreamingResponse emitting "data: <token>\\n\\n" lines as the graph
        produces message chunks.

    Raises:
        HTTPException: 500 on any failure while setting up the stream.
    """
    try:
        # Strip newlines before logging to prevent log-injection/forging.
        sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
        logger.info(f"Generating prose for prompt: {sanitized_prompt}")
        workflow = build_prose_graph()
        events = workflow.astream(
            {
                "content": request.prompt,
                "option": request.option,
                "command": request.command,
            },
            stream_mode="messages",
            subgraphs=True,
        )
        # event is a (message_chunk, metadata) tuple; only the chunk's
        # content is forwarded to the client.
        return StreamingResponse(
            (f"data: {event[0].content}\n\n" async for _, event in events),
            media_type="text/event-stream",
        )
    except Exception as e:
        # NOTE(review): errors raised inside the async generator after the
        # response starts are not caught here — confirm acceptable.
        logger.exception(f"Error occurred during prose generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-04-26 23:12:13 +08:00
|
|
|
|
|
|
|
|
|
2025-06-08 19:41:59 +08:00
|
|
|
@app.post("/api/prompt/enhance")
async def enhance_prompt(request: EnhancePromptRequest):
    """Enhance a user prompt using the prompt-enhancer graph.

    The requested report style string is mapped to a ``ReportStyle`` enum,
    defaulting to ACADEMIC for missing, unknown, or malformed values.

    Returns:
        dict with a single "result" key containing the enhanced prompt.

    Raises:
        HTTPException: 500 on any failure during enhancement.
    """
    try:
        # Strip newlines before logging to prevent log-injection/forging.
        sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
        logger.info(f"Enhancing prompt: {sanitized_prompt}")

        # Convert string report_style to ReportStyle enum
        report_style = None
        if request.report_style:
            try:
                # Handle both uppercase and lowercase input
                style_mapping = {
                    "ACADEMIC": ReportStyle.ACADEMIC,
                    "POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE,
                    "NEWS": ReportStyle.NEWS,
                    "SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA,
                    "STRATEGIC_INVESTMENT": ReportStyle.STRATEGIC_INVESTMENT,
                }
                # Unknown names fall back to ACADEMIC via the .get default;
                # the except below guards non-string values (.upper() failing).
                report_style = style_mapping.get(
                    request.report_style.upper(), ReportStyle.ACADEMIC
                )
            except Exception:
                # If invalid style, default to ACADEMIC
                report_style = ReportStyle.ACADEMIC
        else:
            report_style = ReportStyle.ACADEMIC

        workflow = build_prompt_enhancer_graph()
        final_state = workflow.invoke(
            {
                "prompt": request.prompt,
                "context": request.context,
                "report_style": report_style,
            }
        )
        return {"result": final_state["output"]}
    except Exception as e:
        logger.exception(f"Error occurred during prompt enhancement: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
|
|
|
|
|
|
|
|
|
|
2025-04-23 14:38:04 +08:00
|
|
|
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
async def mcp_server_metadata(request: MCPServerMetadataRequest):
    """Get information about an MCP server."""
    # Check if MCP server configuration is enabled
    # Feature-gated: loading arbitrary MCP servers executes external
    # commands/URLs, so it is off unless explicitly enabled.
    if not get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False):
        raise HTTPException(
            status_code=403,
            detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",
        )

    try:
        # Set default timeout with a longer value for this endpoint
        timeout = 300  # Default to 300 seconds for this endpoint

        # Use custom timeout from request if provided
        if request.timeout_seconds is not None:
            timeout = request.timeout_seconds

        # Load tools from the MCP server using the utility function
        tools = await load_mcp_tools(
            server_type=request.transport,
            command=request.command,
            args=request.args,
            url=request.url,
            env=request.env,
            headers=request.headers,
            timeout_seconds=timeout,
        )

        # Create the response with tools, echoing back the request's
        # connection parameters alongside the discovered tool list.
        response = MCPServerMetadataResponse(
            transport=request.transport,
            command=request.command,
            args=request.args,
            url=request.url,
            env=request.env,
            headers=request.headers,
            tools=tools,
        )

        return response
    except Exception as e:
        logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-05-28 14:13:46 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/rag/config", response_model=RAGConfigResponse)
async def rag_config():
    """Return the currently selected RAG provider configuration."""
    provider = SELECTED_RAG_PROVIDER
    return RAGConfigResponse(provider=provider)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/rag/resources", response_model=RAGResourcesResponse)
async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
    """List RAG resources matching the query, or none if no retriever exists."""
    retriever = build_retriever()
    matched = retriever.list_resources(request.query) if retriever else []
    return RAGResourcesResponse(resources=matched)
|
2025-06-14 13:12:43 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/config", response_model=ConfigResponse)
async def config():
    """Return the server configuration: RAG provider and configured LLMs."""
    rag_section = RAGConfigResponse(provider=SELECTED_RAG_PROVIDER)
    model_section = get_configured_llm_models()
    return ConfigResponse(rag=rag_section, models=model_section)
|