feat: may_ask (#981)

* feat: u may ask

* chore: adjust code according to CR

* chore: adjust code according to CR

* ut: test for suggestions.py

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
null4536251
2026-03-06 22:39:58 +08:00
committed by GitHub
parent 2e90101be8
commit 9d2144d431
10 changed files with 462 additions and 35 deletions

View File

@@ -7,7 +7,16 @@ from fastapi import FastAPI
from src.config.app_config import get_app_config
from src.gateway.config import get_gateway_config
from src.gateway.routers import agents, artifacts, mcp, memory, models, skills, uploads
from src.gateway.routers import (
agents,
artifacts,
mcp,
memory,
models,
skills,
suggestions,
uploads,
)
# Configure logging
logging.basicConfig(
@@ -104,6 +113,10 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
"name": "agents",
"description": "Create and manage custom agents with per-agent config and prompts",
},
{
"name": "suggestions",
"description": "Generate follow-up question suggestions for conversations",
},
{
"name": "health",
"description": "Health check and system status endpoints",
@@ -135,6 +148,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Agents API is mounted at /api/agents
app.include_router(agents.router)
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
app.include_router(suggestions.router)
@app.get("/health", tags=["health"])
async def health_check() -> dict:
"""Health check endpoint.

View File

@@ -1,3 +1,3 @@
from . import artifacts, mcp, models, skills, uploads
from . import artifacts, mcp, models, skills, suggestions, uploads
__all__ = ["artifacts", "mcp", "models", "skills", "uploads"]
__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"]

View File

@@ -0,0 +1,114 @@
import json
import logging
from fastapi import APIRouter
from pydantic import BaseModel, Field
from src.models import create_chat_model
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["suggestions"])
class SuggestionMessage(BaseModel):
role: str = Field(..., description="Message role: user|assistant")
content: str = Field(..., description="Message content as plain text")
class SuggestionsRequest(BaseModel):
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
model_name: str | None = Field(default=None, description="Optional model override")
class SuggestionsResponse(BaseModel):
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
def _strip_markdown_code_fence(text: str) -> str:
stripped = text.strip()
if not stripped.startswith("```"):
return stripped
lines = stripped.splitlines()
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
return "\n".join(lines[1:-1]).strip()
return stripped
def _parse_json_string_list(text: str) -> list[str] | None:
candidate = _strip_markdown_code_fence(text)
start = candidate.find("[")
end = candidate.rfind("]")
if start == -1 or end == -1 or end <= start:
return None
candidate = candidate[start : end + 1]
try:
data = json.loads(candidate)
except Exception:
return None
if not isinstance(data, list):
return None
out: list[str] = []
for item in data:
if not isinstance(item, str):
continue
s = item.strip()
if not s:
continue
out.append(s)
return out
def _format_conversation(messages: list[SuggestionMessage]) -> str:
parts: list[str] = []
for m in messages:
role = m.role.strip().lower()
if role in ("user", "human"):
parts.append(f"User: {m.content.strip()}")
elif role in ("assistant", "ai"):
parts.append(f"Assistant: {m.content.strip()}")
else:
parts.append(f"{m.role}: {m.content.strip()}")
return "\n".join(parts).strip()
@router.post(
"/threads/{thread_id}/suggestions",
response_model=SuggestionsResponse,
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
if not request.messages:
return SuggestionsResponse(suggestions=[])
n = request.n
conversation = _format_conversation(request.messages)
if not conversation:
return SuggestionsResponse(suggestions=[])
prompt = (
"You are generating follow-up questions to help the user continue the conversation.\n"
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
"Requirements:\n"
"- Questions must be relevant to the conversation.\n"
"- Questions must be written in the same language as the user.\n"
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
"- Do NOT include numbering, markdown, or any extra text.\n"
"- Output MUST be a JSON array of strings only.\n\n"
"Conversation:\n"
f"{conversation}\n"
).format(n=n, conversation=conversation)
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = model.invoke(prompt)
raw = str(response.content or "")
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
cleaned = cleaned[:n]
return SuggestionsResponse(suggestions=cleaned)
except Exception as exc:
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
return SuggestionsResponse(suggestions=[])