feat: may_ask (#981)

* feat: u may ask

* chore: adjust code according to CR

* chore: adjust code according to CR

* ut: test for suggestions.py

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
null4536251
2026-03-06 22:39:58 +08:00
committed by GitHub
parent 2e90101be8
commit 9d2144d431
10 changed files with 462 additions and 35 deletions

View File

@@ -7,7 +7,16 @@ from fastapi import FastAPI
from src.config.app_config import get_app_config
from src.gateway.config import get_gateway_config
from src.gateway.routers import agents, artifacts, mcp, memory, models, skills, uploads
from src.gateway.routers import (
agents,
artifacts,
mcp,
memory,
models,
skills,
suggestions,
uploads,
)
# Configure logging
logging.basicConfig(
@@ -104,6 +113,10 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
"name": "agents",
"description": "Create and manage custom agents with per-agent config and prompts",
},
{
"name": "suggestions",
"description": "Generate follow-up question suggestions for conversations",
},
{
"name": "health",
"description": "Health check and system status endpoints",
@@ -135,6 +148,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Agents API is mounted at /api/agents
app.include_router(agents.router)
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
app.include_router(suggestions.router)
@app.get("/health", tags=["health"])
async def health_check() -> dict:
"""Health check endpoint.

View File

@@ -1,3 +1,3 @@
from . import artifacts, mcp, models, skills, uploads
from . import artifacts, mcp, models, skills, suggestions, uploads
__all__ = ["artifacts", "mcp", "models", "skills", "uploads"]
__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"]

View File

@@ -0,0 +1,114 @@
import json
import logging
from fastapi import APIRouter
from pydantic import BaseModel, Field
from src.models import create_chat_model
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["suggestions"])
class SuggestionMessage(BaseModel):
    """A single conversation turn supplied as context for suggestion generation."""

    role: str = Field(..., description="Message role: user|assistant")
    content: str = Field(..., description="Message content as plain text")
class SuggestionsRequest(BaseModel):
    """Request body for POST /api/threads/{thread_id}/suggestions."""

    messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
    n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
    model_name: str | None = Field(default=None, description="Optional model override")
class SuggestionsResponse(BaseModel):
    """Response body: zero or more suggested follow-up questions (may be empty on failure)."""

    suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
def _strip_markdown_code_fence(text: str) -> str:
stripped = text.strip()
if not stripped.startswith("```"):
return stripped
lines = stripped.splitlines()
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
return "\n".join(lines[1:-1]).strip()
return stripped
def _parse_json_string_list(text: str) -> list[str] | None:
candidate = _strip_markdown_code_fence(text)
start = candidate.find("[")
end = candidate.rfind("]")
if start == -1 or end == -1 or end <= start:
return None
candidate = candidate[start : end + 1]
try:
data = json.loads(candidate)
except Exception:
return None
if not isinstance(data, list):
return None
out: list[str] = []
for item in data:
if not isinstance(item, str):
continue
s = item.strip()
if not s:
continue
out.append(s)
return out
def _format_conversation(messages: list[SuggestionMessage]) -> str:
    """Render the messages as newline-separated ``Role: content`` lines.

    ``user``/``human`` roles are labelled ``User`` and ``assistant``/``ai``
    are labelled ``Assistant``; any other role is passed through verbatim.
    """
    rendered: list[str] = []
    for message in messages:
        normalized_role = message.role.strip().lower()
        if normalized_role in ("user", "human"):
            label = "User"
        elif normalized_role in ("assistant", "ai"):
            label = "Assistant"
        else:
            label = message.role
        rendered.append(f"{label}: {message.content.strip()}")
    return "\n".join(rendered).strip()
@router.post(
    "/threads/{thread_id}/suggestions",
    response_model=SuggestionsResponse,
    summary="Generate Follow-up Questions",
    description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
    """Generate up to ``request.n`` follow-up questions for a conversation.

    Best-effort endpoint: returns an empty suggestion list (never an error)
    when there is no usable context, or when the model call or output
    parsing fails.
    """
    if not request.messages:
        return SuggestionsResponse(suggestions=[])
    n = request.n
    conversation = _format_conversation(request.messages)
    if not conversation:
        return SuggestionsResponse(suggestions=[])
    # BUGFIX: the prompt is fully built via f-string interpolation; the
    # previous trailing ``.format(n=n, conversation=conversation)`` call
    # re-processed the already-interpolated text, so any literal ``{`` or
    # ``}`` in user messages (code snippets, JSON, ...) raised
    # KeyError/ValueError outside the try block below -> HTTP 500.
    prompt = (
        "You are generating follow-up questions to help the user continue the conversation.\n"
        f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
        "Requirements:\n"
        "- Questions must be relevant to the conversation.\n"
        "- Questions must be written in the same language as the user.\n"
        "- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
        "- Do NOT include numbering, markdown, or any extra text.\n"
        "- Output MUST be a JSON array of strings only.\n\n"
        "Conversation:\n"
        f"{conversation}\n"
    )
    try:
        model = create_chat_model(name=request.model_name, thinking_enabled=False)
        response = model.invoke(prompt)
        raw = str(response.content or "")
        suggestions = _parse_json_string_list(raw) or []
        # Flatten any embedded newlines and cap at the requested count.
        cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
        cleaned = cleaned[:n]
        return SuggestionsResponse(suggestions=cleaned)
    except Exception as exc:  # degrade gracefully: suggestions are optional UX
        logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
        return SuggestionsResponse(suggestions=[])

View File

@@ -0,0 +1,66 @@
import asyncio
from unittest.mock import MagicMock
from src.gateway.routers import suggestions
def test_strip_markdown_code_fence_removes_wrapping():
    """A ```json fenced payload is unwrapped to its inner content."""
    fenced = "```json\n[\"a\"]\n```"
    unwrapped = suggestions._strip_markdown_code_fence(fenced)
    assert unwrapped == "[\"a\"]"
def test_strip_markdown_code_fence_no_fence_keeps_content():
    """Unfenced input is only whitespace-stripped, never altered."""
    plain = " [\"a\"] "
    assert suggestions._strip_markdown_code_fence(plain) == "[\"a\"]"
def test_parse_json_string_list_filters_invalid_items():
    """Non-string and blank entries are dropped from the parsed array."""
    payload = "```json\n[\"a\", \" \", 1, \"b\"]\n```"
    parsed = suggestions._parse_json_string_list(payload)
    assert parsed == ["a", "b"]
def test_parse_json_string_list_rejects_non_list():
    """A JSON object (rather than an array) yields None."""
    payload = "{\"a\": 1}"
    assert suggestions._parse_json_string_list(payload) is None
def test_format_conversation_formats_roles():
    """Known roles are normalized; unknown roles pass through verbatim."""
    history = [
        suggestions.SuggestionMessage(role="User", content="Hi"),
        suggestions.SuggestionMessage(role="assistant", content="Hello"),
        suggestions.SuggestionMessage(role="system", content="note"),
    ]
    expected = "User: Hi\nAssistant: Hello\nsystem: note"
    assert suggestions._format_conversation(history) == expected
def test_generate_suggestions_parses_and_limits(monkeypatch):
    """Fenced JSON model output is parsed and truncated to the requested count."""
    request = suggestions.SuggestionsRequest(
        messages=[
            suggestions.SuggestionMessage(role="user", content="Hi"),
            suggestions.SuggestionMessage(role="assistant", content="Hello"),
        ],
        n=3,
        model_name=None,
    )
    stub_model = MagicMock()
    stub_model.invoke.return_value = MagicMock(
        content="```json\n[\"Q1\", \"Q2\", \"Q3\", \"Q4\"]\n```"
    )
    monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: stub_model)
    response = asyncio.run(suggestions.generate_suggestions("t1", request))
    assert response.suggestions == ["Q1", "Q2", "Q3"]
def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
    """A model invocation failure degrades to an empty suggestion list."""
    request = suggestions.SuggestionsRequest(
        messages=[suggestions.SuggestionMessage(role="user", content="Hi")],
        n=2,
        model_name=None,
    )
    stub_model = MagicMock()
    stub_model.invoke.side_effect = RuntimeError("boom")
    monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: stub_model)
    response = asyncio.run(suggestions.generate_suggestions("t1", request))
    assert response.suggestions == []

View File

@@ -151,6 +151,7 @@ export default function AgentChatPage() {
<InputBox
className={cn("bg-background/5 w-full -translate-y-4")}
isNewThread={isNewThread}
threadId={threadId}
autoFocus={isNewThread}
status={thread.isLoading ? "streaming" : "ready"}
context={settings.context}

View File

@@ -119,6 +119,7 @@ export default function ChatPage() {
<InputBox
className={cn("bg-background/5 w-full -translate-y-4")}
isNewThread={isNewThread}
threadId={threadId}
autoFocus={isNewThread}
status={thread.isLoading ? "streaming" : "ready"}
context={settings.context}

View File

@@ -9,6 +9,7 @@ import {
PlusIcon,
SparklesIcon,
RocketIcon,
XIcon,
ZapIcon,
} from "lucide-react";
import { useSearchParams } from "next/navigation";
@@ -16,6 +17,7 @@ import {
useCallback,
useEffect,
useMemo,
useRef,
useState,
type ComponentProps,
} from "react";
@@ -38,15 +40,26 @@ import {
usePromptInputController,
type PromptInputMessage,
} from "@/components/ai-elements/prompt-input";
import { Button } from "@/components/ui/button";
import { ConfettiButton } from "@/components/ui/confetti-button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import {
DropdownMenuGroup,
DropdownMenuLabel,
DropdownMenuSeparator,
} from "@/components/ui/dropdown-menu";
import { getBackendBaseURL } from "@/core/config";
import { useI18n } from "@/core/i18n/hooks";
import { useModels } from "@/core/models/hooks";
import type { AgentThreadContext } from "@/core/threads";
import { textOfMessage } from "@/core/threads/utils";
import { cn } from "@/lib/utils";
import {
@@ -66,6 +79,7 @@ import {
DropdownMenuTrigger,
} from "../ui/dropdown-menu";
import { useThread } from "./messages/context";
import { ModeHoverGuide } from "./mode-hover-guide";
import { Tooltip } from "./tooltip";
@@ -92,6 +106,7 @@ export function InputBox({
context,
extraHeader,
isNewThread,
threadId,
initialValue,
onContextChange,
onSubmit,
@@ -110,6 +125,7 @@ export function InputBox({
};
extraHeader?: React.ReactNode;
isNewThread?: boolean;
threadId: string;
initialValue?: string;
onContextChange?: (
context: Omit<
@@ -127,6 +143,20 @@ export function InputBox({
const searchParams = useSearchParams();
const [modelDialogOpen, setModelDialogOpen] = useState(false);
const { models } = useModels();
const { thread, isMock } = useThread();
const { textInput } = usePromptInputController();
const promptRootRef = useRef<HTMLDivElement | null>(null);
const [followups, setFollowups] = useState<string[]>([]);
const [followupsHidden, setFollowupsHidden] = useState(false);
const [followupsLoading, setFollowupsLoading] = useState(false);
const lastGeneratedForAiIdRef = useRef<string | null>(null);
const wasStreamingRef = useRef(false);
const [confirmOpen, setConfirmOpen] = useState(false);
const [pendingSuggestion, setPendingSuggestion] = useState<string | null>(
null,
);
useEffect(() => {
if (models.length === 0) {
@@ -213,43 +243,168 @@ export function InputBox({
if (!message.text) {
return;
}
setFollowups([]);
setFollowupsHidden(false);
setFollowupsLoading(false);
onSubmit?.(message);
},
[onSubmit, onStop, status],
);
const requestFormSubmit = useCallback(() => {
const form = promptRootRef.current?.querySelector("form");
form?.requestSubmit();
}, []);
const handleFollowupClick = useCallback(
(suggestion: string) => {
if (status === "streaming") {
return;
}
const current = (textInput.value ?? "").trim();
if (current) {
setPendingSuggestion(suggestion);
setConfirmOpen(true);
return;
}
textInput.setInput(suggestion);
setFollowupsHidden(true);
setTimeout(() => requestFormSubmit(), 0);
},
[requestFormSubmit, status, textInput],
);
const confirmReplaceAndSend = useCallback(() => {
if (!pendingSuggestion) {
setConfirmOpen(false);
return;
}
textInput.setInput(pendingSuggestion);
setFollowupsHidden(true);
setConfirmOpen(false);
setPendingSuggestion(null);
setTimeout(() => requestFormSubmit(), 0);
}, [pendingSuggestion, requestFormSubmit, textInput]);
const confirmAppendAndSend = useCallback(() => {
if (!pendingSuggestion) {
setConfirmOpen(false);
return;
}
const current = (textInput.value ?? "").trim();
const next = current ? `${current}\n${pendingSuggestion}` : pendingSuggestion;
textInput.setInput(next);
setFollowupsHidden(true);
setConfirmOpen(false);
setPendingSuggestion(null);
setTimeout(() => requestFormSubmit(), 0);
}, [pendingSuggestion, requestFormSubmit, textInput]);
useEffect(() => {
const streaming = status === "streaming";
const wasStreaming = wasStreamingRef.current;
wasStreamingRef.current = streaming;
if (!wasStreaming || streaming) {
return;
}
if (disabled || isMock) {
return;
}
const lastAi = [...thread.messages].reverse().find((m) => m.type === "ai");
const lastAiId = lastAi?.id ?? null;
if (!lastAiId || lastAiId === lastGeneratedForAiIdRef.current) {
return;
}
lastGeneratedForAiIdRef.current = lastAiId;
const recent = thread.messages
.filter((m) => m.type === "human" || m.type === "ai")
.map((m) => {
const role = m.type === "human" ? "user" : "assistant";
const content = textOfMessage(m) ?? "";
return { role, content };
})
.filter((m) => m.content.trim().length > 0)
.slice(-6);
if (recent.length === 0) {
return;
}
const controller = new AbortController();
setFollowupsHidden(false);
setFollowupsLoading(true);
setFollowups([]);
fetch(`${getBackendBaseURL()}/api/threads/${threadId}/suggestions`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
messages: recent,
n: 3,
model_name: context.model_name ?? undefined,
}),
signal: controller.signal,
})
.then(async (res) => {
if (!res.ok) {
return { suggestions: [] as string[] };
}
return (await res.json()) as { suggestions?: string[] };
})
.then((data) => {
const suggestions = (data.suggestions ?? [])
.map((s) => (typeof s === "string" ? s.trim() : ""))
.filter((s) => s.length > 0)
.slice(0, 5);
setFollowups(suggestions);
})
.catch(() => {
setFollowups([]);
})
.finally(() => {
setFollowupsLoading(false);
});
return () => controller.abort();
}, [context.model_name, disabled, isMock, status, thread.messages, threadId]);
return (
<PromptInput
className={cn(
"bg-background/85 rounded-2xl backdrop-blur-sm transition-all duration-300 ease-out *:data-[slot='input-group']:rounded-2xl",
className,
)}
disabled={disabled}
globalDrop
multiple
onSubmit={handleSubmit}
{...props}
>
{extraHeader && (
<div className="absolute top-0 right-0 left-0 z-10">
<div className="absolute right-0 bottom-0 left-0 flex items-center justify-center">
{extraHeader}
<div ref={promptRootRef} className="relative">
<PromptInput
className={cn(
"bg-background/85 rounded-2xl backdrop-blur-sm transition-all duration-300 ease-out *:data-[slot='input-group']:rounded-2xl",
className,
)}
disabled={disabled}
globalDrop
multiple
onSubmit={handleSubmit}
{...props}
>
{extraHeader && (
<div className="absolute top-0 right-0 left-0 z-10">
<div className="absolute right-0 bottom-0 left-0 flex items-center justify-center">
{extraHeader}
</div>
</div>
</div>
)}
<PromptInputAttachments>
{(attachment) => <PromptInputAttachment data={attachment} />}
</PromptInputAttachments>
<PromptInputBody className="absolute top-0 right-0 left-0 z-3">
<PromptInputTextarea
className={cn("size-full")}
disabled={disabled}
placeholder={t.inputBox.placeholder}
autoFocus={autoFocus}
defaultValue={initialValue}
/>
</PromptInputBody>
<PromptInputFooter className="flex">
<PromptInputTools>
)}
<PromptInputAttachments>
{(attachment) => <PromptInputAttachment data={attachment} />}
</PromptInputAttachments>
<PromptInputBody className="absolute top-0 right-0 left-0 z-3">
<PromptInputTextarea
className={cn("size-full")}
disabled={disabled}
placeholder={t.inputBox.placeholder}
autoFocus={autoFocus}
defaultValue={initialValue}
/>
</PromptInputBody>
<PromptInputFooter className="flex">
<PromptInputTools>
{/* TODO: Add more connectors here
<PromptInputActionMenu>
<PromptInputActionMenuTrigger className="px-2!" />
@@ -588,7 +743,65 @@ export function InputBox({
{!isNewThread && (
<div className="bg-background absolute right-0 -bottom-[17px] left-0 z-0 h-4"></div>
)}
</PromptInput>
</PromptInput>
{!disabled &&
!isNewThread &&
!followupsHidden &&
(followupsLoading || followups.length > 0) && (
<div className="absolute right-0 -top-20 left-0 z-20 flex items-center justify-center">
<div className="flex items-center gap-2">
{followupsLoading ? (
<div className="text-muted-foreground bg-background/80 rounded-full border px-4 py-2 text-xs backdrop-blur-sm">
{t.inputBox.followupLoading}
</div>
) : (
<Suggestions className="min-h-16 w-fit items-start">
{followups.map((s) => (
<Suggestion
key={s}
suggestion={s}
onClick={() => handleFollowupClick(s)}
/>
))}
<Button
aria-label={t.common.close}
className="text-muted-foreground cursor-pointer rounded-full px-3 text-xs font-normal"
variant="outline"
size="sm"
type="button"
onClick={() => setFollowupsHidden(true)}
>
<XIcon className="size-4" />
</Button>
</Suggestions>
)}
</div>
</div>
)}
<Dialog open={confirmOpen} onOpenChange={setConfirmOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>{t.inputBox.followupConfirmTitle}</DialogTitle>
<DialogDescription>
{t.inputBox.followupConfirmDescription}
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button variant="outline" onClick={() => setConfirmOpen(false)}>
{t.common.cancel}
</Button>
<Button variant="secondary" onClick={confirmAppendAndSend}>
{t.inputBox.followupConfirmAppend}
</Button>
<Button onClick={confirmReplaceAndSend}>
{t.inputBox.followupConfirmReplace}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
);
}

View File

@@ -96,6 +96,12 @@ export const enUS: Translations = {
searchModels: "Search models...",
surpriseMe: "Surprise",
surpriseMePrompt: "Surprise me",
followupLoading: "Generating follow-up questions...",
followupConfirmTitle: "Send suggestion?",
followupConfirmDescription:
"You already have text in the input. Choose how to send it.",
followupConfirmAppend: "Append & send",
followupConfirmReplace: "Replace & send",
suggestions: [
{
suggestion: "Write",

View File

@@ -76,6 +76,11 @@ export interface Translations {
searchModels: string;
surpriseMe: string;
surpriseMePrompt: string;
followupLoading: string;
followupConfirmTitle: string;
followupConfirmDescription: string;
followupConfirmAppend: string;
followupConfirmReplace: string;
suggestions: {
suggestion: string;
prompt: string;

View File

@@ -92,6 +92,11 @@ export const zhCN: Translations = {
searchModels: "搜索模型...",
surpriseMe: "小惊喜",
surpriseMePrompt: "给我一个小惊喜吧",
followupLoading: "正在生成可能的后续问题...",
followupConfirmTitle: "发送建议问题?",
followupConfirmDescription: "当前输入框已有内容,选择发送方式。",
followupConfirmAppend: "追加并发送",
followupConfirmReplace: "替换并发送",
suggestions: [
{
suggestion: "写作",