feat: Add configurable timeout support for MCP server operations with improved defaults

This commit is contained in:
Wang Hao
2025-05-01 23:16:57 +08:00
parent aeca7a5707
commit f30ec77d6f
3 changed files with 18 additions and 4 deletions

View File

@@ -283,6 +283,13 @@ async def generate_prose(request: GenerateProseRequest):
async def mcp_server_metadata(request: MCPServerMetadataRequest): async def mcp_server_metadata(request: MCPServerMetadataRequest):
"""Get information about an MCP server.""" """Get information about an MCP server."""
try: try:
# Set default timeout with a longer value for this endpoint
timeout = 300 # Default to 300 seconds for this endpoint
# Use custom timeout from request if provided
if request.timeout_seconds is not None:
timeout = request.timeout_seconds
# Load tools from the MCP server using the utility function # Load tools from the MCP server using the utility function
tools = await load_mcp_tools( tools = await load_mcp_tools(
server_type=request.transport, server_type=request.transport,
@@ -290,6 +297,7 @@ async def mcp_server_metadata(request: MCPServerMetadataRequest):
args=request.args, args=request.args,
url=request.url, url=request.url,
env=request.env, env=request.env,
timeout_seconds=timeout,
) )
# Create the response with tools # Create the response with tools

View File

@@ -22,6 +22,9 @@ class MCPServerMetadataRequest(BaseModel):
None, description="The URL of the SSE server (for sse type)" None, description="The URL of the SSE server (for sse type)"
) )
env: Optional[Dict[str, str]] = Field(None, description="Environment variables") env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
timeout_seconds: Optional[int] = Field(
None, description="Optional custom timeout in seconds for the operation"
)
class MCPServerMetadataResponse(BaseModel): class MCPServerMetadataResponse(BaseModel):

View File

@@ -13,12 +13,13 @@ from mcp.client.sse import sse_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def _get_tools_from_client_session(client_context_manager: Any) -> List: async def _get_tools_from_client_session(client_context_manager: Any, timeout_seconds: int = 10) -> List:
""" """
Helper function to get tools from a client session. Helper function to get tools from a client session.
Args: Args:
client_context_manager: A context manager that returns (read, write) functions client_context_manager: A context manager that returns (read, write) functions
timeout_seconds: Timeout in seconds for the read operation
Returns: Returns:
List of available tools from the MCP server List of available tools from the MCP server
@@ -28,7 +29,7 @@ async def _get_tools_from_client_session(client_context_manager: Any) -> List:
""" """
async with client_context_manager as (read, write): async with client_context_manager as (read, write):
async with ClientSession( async with ClientSession(
read, write, read_timeout_seconds=timedelta(seconds=10) read, write, read_timeout_seconds=timedelta(seconds=timeout_seconds)
) as session: ) as session:
# Initialize the connection # Initialize the connection
await session.initialize() await session.initialize()
@@ -43,6 +44,7 @@ async def load_mcp_tools(
args: Optional[List[str]] = None, args: Optional[List[str]] = None,
url: Optional[str] = None, url: Optional[str] = None,
env: Optional[Dict[str, str]] = None, env: Optional[Dict[str, str]] = None,
timeout_seconds: int = 60, # Longer default timeout for first-time executions
) -> List: ) -> List:
""" """
Load tools from an MCP server. Load tools from an MCP server.
@@ -53,6 +55,7 @@ async def load_mcp_tools(
args: Command arguments (for stdio type) args: Command arguments (for stdio type)
url: The URL of the SSE server (for sse type) url: The URL of the SSE server (for sse type)
env: Environment variables env: Environment variables
timeout_seconds: Timeout in seconds (default: 60 for first-time executions)
Returns: Returns:
List of available tools from the MCP server List of available tools from the MCP server
@@ -73,7 +76,7 @@ async def load_mcp_tools(
env=env, # Optional environment variables env=env, # Optional environment variables
) )
return await _get_tools_from_client_session(stdio_client(server_params)) return await _get_tools_from_client_session(stdio_client(server_params), timeout_seconds)
elif server_type == "sse": elif server_type == "sse":
if not url: if not url:
@@ -81,7 +84,7 @@ async def load_mcp_tools(
status_code=400, detail="URL is required for sse type" status_code=400, detail="URL is required for sse type"
) )
return await _get_tools_from_client_session(sse_client(url=url)) return await _get_tools_from_client_session(sse_client(url=url), timeout_seconds)
else: else:
raise HTTPException( raise HTTPException(