From 9d2144d431a81960936fb9ae313a34f698c4d236 Mon Sep 17 00:00:00 2001 From: null4536251 <77783126+null4536251@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:39:58 +0800 Subject: [PATCH] feat: may_ask (#981) * feat: u may ask * chore: adjust code according to CR * chore: adjust code according to CR * ut: test for suggestions.py --------- Co-authored-by: Willem Jiang --- backend/src/gateway/app.py | 18 +- backend/src/gateway/routers/__init__.py | 4 +- backend/src/gateway/routers/suggestions.py | 114 +++++++ backend/tests/test_suggestions_router.py | 66 +++++ .../[agent_name]/chats/[thread_id]/page.tsx | 1 + .../app/workspace/chats/[thread_id]/page.tsx | 1 + .../src/components/workspace/input-box.tsx | 277 ++++++++++++++++-- frontend/src/core/i18n/locales/en-US.ts | 6 + frontend/src/core/i18n/locales/types.ts | 5 + frontend/src/core/i18n/locales/zh-CN.ts | 5 + 10 files changed, 462 insertions(+), 35 deletions(-) create mode 100644 backend/src/gateway/routers/suggestions.py create mode 100644 backend/tests/test_suggestions_router.py diff --git a/backend/src/gateway/app.py b/backend/src/gateway/app.py index 48c1bed..fd2e51a 100644 --- a/backend/src/gateway/app.py +++ b/backend/src/gateway/app.py @@ -7,7 +7,16 @@ from fastapi import FastAPI from src.config.app_config import get_app_config from src.gateway.config import get_gateway_config -from src.gateway.routers import agents, artifacts, mcp, memory, models, skills, uploads +from src.gateway.routers import ( + agents, + artifacts, + mcp, + memory, + models, + skills, + suggestions, + uploads, +) # Configure logging logging.basicConfig( @@ -104,6 +113,10 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an "name": "agents", "description": "Create and manage custom agents with per-agent config and prompts", }, + { + "name": "suggestions", + "description": "Generate follow-up question suggestions for conversations", + }, { "name": "health", "description": "Health check and system status 
endpoints", @@ -135,6 +148,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an # Agents API is mounted at /api/agents app.include_router(agents.router) + # Suggestions API is mounted at /api/threads/{thread_id}/suggestions + app.include_router(suggestions.router) + @app.get("/health", tags=["health"]) async def health_check() -> dict: """Health check endpoint. diff --git a/backend/src/gateway/routers/__init__.py b/backend/src/gateway/routers/__init__.py index 62a0bd2..0652330 100644 --- a/backend/src/gateway/routers/__init__.py +++ b/backend/src/gateway/routers/__init__.py @@ -1,3 +1,3 @@ -from . import artifacts, mcp, models, skills, uploads +from . import artifacts, mcp, models, skills, suggestions, uploads -__all__ = ["artifacts", "mcp", "models", "skills", "uploads"] +__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"] diff --git a/backend/src/gateway/routers/suggestions.py b/backend/src/gateway/routers/suggestions.py new file mode 100644 index 0000000..031f3bc --- /dev/null +++ b/backend/src/gateway/routers/suggestions.py @@ -0,0 +1,114 @@ +import json +import logging + +from fastapi import APIRouter +from pydantic import BaseModel, Field + +from src.models import create_chat_model + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api", tags=["suggestions"]) + + +class SuggestionMessage(BaseModel): + role: str = Field(..., description="Message role: user|assistant") + content: str = Field(..., description="Message content as plain text") + + +class SuggestionsRequest(BaseModel): + messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages") + n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate") + model_name: str | None = Field(default=None, description="Optional model override") + + +class SuggestionsResponse(BaseModel): + suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions") 
+ + +def _strip_markdown_code_fence(text: str) -> str: + stripped = text.strip() + if not stripped.startswith("```"): + return stripped + lines = stripped.splitlines() + if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"): + return "\n".join(lines[1:-1]).strip() + return stripped + + +def _parse_json_string_list(text: str) -> list[str] | None: + candidate = _strip_markdown_code_fence(text) + start = candidate.find("[") + end = candidate.rfind("]") + if start == -1 or end == -1 or end <= start: + return None + candidate = candidate[start : end + 1] + try: + data = json.loads(candidate) + except Exception: + return None + if not isinstance(data, list): + return None + out: list[str] = [] + for item in data: + if not isinstance(item, str): + continue + s = item.strip() + if not s: + continue + out.append(s) + return out + + +def _format_conversation(messages: list[SuggestionMessage]) -> str: + parts: list[str] = [] + for m in messages: + role = m.role.strip().lower() + if role in ("user", "human"): + parts.append(f"User: {m.content.strip()}") + elif role in ("assistant", "ai"): + parts.append(f"Assistant: {m.content.strip()}") + else: + parts.append(f"{m.role}: {m.content.strip()}") + return "\n".join(parts).strip() + + +@router.post( + "/threads/{thread_id}/suggestions", + response_model=SuggestionsResponse, + summary="Generate Follow-up Questions", + description="Generate short follow-up questions a user might ask next, based on recent conversation context.", +) +async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse: + if not request.messages: + return SuggestionsResponse(suggestions=[]) + + n = request.n + conversation = _format_conversation(request.messages) + if not conversation: + return SuggestionsResponse(suggestions=[]) + + prompt = ( + "You are generating follow-up questions to help the user continue the conversation.\n" + f"Based on the conversation below, produce EXACTLY {n} short 
questions the user might ask next.\n" + "Requirements:\n" + "- Questions must be relevant to the conversation.\n" + "- Questions must be written in the same language as the user.\n" + "- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n" + "- Do NOT include numbering, markdown, or any extra text.\n" + "- Output MUST be a JSON array of strings only.\n\n" + "Conversation:\n" + f"{conversation}\n" + ) + + try: + model = create_chat_model(name=request.model_name, thinking_enabled=False) + response = model.invoke(prompt) + raw = str(response.content or "") + suggestions = _parse_json_string_list(raw) or [] + cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()] + cleaned = cleaned[:n] + return SuggestionsResponse(suggestions=cleaned) + except Exception as exc: + logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc) + return SuggestionsResponse(suggestions=[]) diff --git a/backend/tests/test_suggestions_router.py b/backend/tests/test_suggestions_router.py new file mode 100644 index 0000000..538d3cb --- /dev/null +++ b/backend/tests/test_suggestions_router.py @@ -0,0 +1,66 @@ +import asyncio +from unittest.mock import MagicMock + +from src.gateway.routers import suggestions + + +def test_strip_markdown_code_fence_removes_wrapping(): + text = "```json\n[\"a\"]\n```" + assert suggestions._strip_markdown_code_fence(text) == "[\"a\"]" + + +def test_strip_markdown_code_fence_no_fence_keeps_content(): + text = " [\"a\"] " + assert suggestions._strip_markdown_code_fence(text) == "[\"a\"]" + + +def test_parse_json_string_list_filters_invalid_items(): + text = "```json\n[\"a\", \" \", 1, \"b\"]\n```" + assert suggestions._parse_json_string_list(text) == ["a", "b"] + + +def test_parse_json_string_list_rejects_non_list(): + text = "{\"a\": 1}" + assert suggestions._parse_json_string_list(text) is None + + +def test_format_conversation_formats_roles(): + messages 
= [ + suggestions.SuggestionMessage(role="User", content="Hi"), + suggestions.SuggestionMessage(role="assistant", content="Hello"), + suggestions.SuggestionMessage(role="system", content="note"), + ] + assert suggestions._format_conversation(messages) == "User: Hi\nAssistant: Hello\nsystem: note" + + +def test_generate_suggestions_parses_and_limits(monkeypatch): + req = suggestions.SuggestionsRequest( + messages=[ + suggestions.SuggestionMessage(role="user", content="Hi"), + suggestions.SuggestionMessage(role="assistant", content="Hello"), + ], + n=3, + model_name=None, + ) + fake_model = MagicMock() + fake_model.invoke.return_value = MagicMock(content="```json\n[\"Q1\", \"Q2\", \"Q3\", \"Q4\"]\n```") + monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) + + result = asyncio.run(suggestions.generate_suggestions("t1", req)) + + assert result.suggestions == ["Q1", "Q2", "Q3"] + + +def test_generate_suggestions_returns_empty_on_model_error(monkeypatch): + req = suggestions.SuggestionsRequest( + messages=[suggestions.SuggestionMessage(role="user", content="Hi")], + n=2, + model_name=None, + ) + fake_model = MagicMock() + fake_model.invoke.side_effect = RuntimeError("boom") + monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) + + result = asyncio.run(suggestions.generate_suggestions("t1", req)) + + assert result.suggestions == [] \ No newline at end of file diff --git a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx index 6bbb597..0a5236d 100644 --- a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx @@ -151,6 +151,7 @@ export default function AgentChatPage() { (null); + + const [followups, setFollowups] = useState([]); + const [followupsHidden, setFollowupsHidden] = useState(false); + const 
[followupsLoading, setFollowupsLoading] = useState(false); + const lastGeneratedForAiIdRef = useRef(null); + const wasStreamingRef = useRef(false); + + const [confirmOpen, setConfirmOpen] = useState(false); + const [pendingSuggestion, setPendingSuggestion] = useState( + null, + ); useEffect(() => { if (models.length === 0) { @@ -213,43 +243,168 @@ export function InputBox({ if (!message.text) { return; } + setFollowups([]); + setFollowupsHidden(false); + setFollowupsLoading(false); onSubmit?.(message); }, [onSubmit, onStop, status], ); + + const requestFormSubmit = useCallback(() => { + const form = promptRootRef.current?.querySelector("form"); + form?.requestSubmit(); + }, []); + + const handleFollowupClick = useCallback( + (suggestion: string) => { + if (status === "streaming") { + return; + } + const current = (textInput.value ?? "").trim(); + if (current) { + setPendingSuggestion(suggestion); + setConfirmOpen(true); + return; + } + textInput.setInput(suggestion); + setFollowupsHidden(true); + setTimeout(() => requestFormSubmit(), 0); + }, + [requestFormSubmit, status, textInput], + ); + + const confirmReplaceAndSend = useCallback(() => { + if (!pendingSuggestion) { + setConfirmOpen(false); + return; + } + textInput.setInput(pendingSuggestion); + setFollowupsHidden(true); + setConfirmOpen(false); + setPendingSuggestion(null); + setTimeout(() => requestFormSubmit(), 0); + }, [pendingSuggestion, requestFormSubmit, textInput]); + + const confirmAppendAndSend = useCallback(() => { + if (!pendingSuggestion) { + setConfirmOpen(false); + return; + } + const current = (textInput.value ?? "").trim(); + const next = current ? 
`${current}\n${pendingSuggestion}` : pendingSuggestion; + textInput.setInput(next); + setFollowupsHidden(true); + setConfirmOpen(false); + setPendingSuggestion(null); + setTimeout(() => requestFormSubmit(), 0); + }, [pendingSuggestion, requestFormSubmit, textInput]); + + useEffect(() => { + const streaming = status === "streaming"; + const wasStreaming = wasStreamingRef.current; + wasStreamingRef.current = streaming; + if (!wasStreaming || streaming) { + return; + } + + if (disabled || isMock) { + return; + } + + const lastAi = [...thread.messages].reverse().find((m) => m.type === "ai"); + const lastAiId = lastAi?.id ?? null; + if (!lastAiId || lastAiId === lastGeneratedForAiIdRef.current) { + return; + } + lastGeneratedForAiIdRef.current = lastAiId; + + const recent = thread.messages + .filter((m) => m.type === "human" || m.type === "ai") + .map((m) => { + const role = m.type === "human" ? "user" : "assistant"; + const content = textOfMessage(m) ?? ""; + return { role, content }; + }) + .filter((m) => m.content.trim().length > 0) + .slice(-6); + + if (recent.length === 0) { + return; + } + + const controller = new AbortController(); + setFollowupsHidden(false); + setFollowupsLoading(true); + setFollowups([]); + + fetch(`${getBackendBaseURL()}/api/threads/${threadId}/suggestions`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + messages: recent, + n: 3, + model_name: context.model_name ?? undefined, + }), + signal: controller.signal, + }) + .then(async (res) => { + if (!res.ok) { + return { suggestions: [] as string[] }; + } + return (await res.json()) as { suggestions?: string[] }; + }) + .then((data) => { + const suggestions = (data.suggestions ?? []) + .map((s) => (typeof s === "string" ? 
s.trim() : "")) + .filter((s) => s.length > 0) + .slice(0, 5); + setFollowups(suggestions); + }) + .catch(() => { + setFollowups([]); + }) + .finally(() => { + setFollowupsLoading(false); + }); + + return () => controller.abort(); + }, [context.model_name, disabled, isMock, status, thread.messages, threadId]); + return ( - - {extraHeader && ( -
-
- {extraHeader} +
+ + {extraHeader && ( +
+
+ {extraHeader} +
-
- )} - - {(attachment) => } - - - - - - + )} + + {(attachment) => } + + + + + + {/* TODO: Add more connectors here @@ -588,7 +743,65 @@ export function InputBox({ {!isNewThread && (
)} - + + + {!disabled && + !isNewThread && + !followupsHidden && + (followupsLoading || followups.length > 0) && ( +
+
+ {followupsLoading ? ( +
+ {t.inputBox.followupLoading} +
+ ) : ( + + {followups.map((s) => ( + handleFollowupClick(s)} + /> + ))} + + + )} +
+
+ )} + + + + + {t.inputBox.followupConfirmTitle} + + {t.inputBox.followupConfirmDescription} + + + + + + + + + +
); } diff --git a/frontend/src/core/i18n/locales/en-US.ts b/frontend/src/core/i18n/locales/en-US.ts index ce1c90b..bc4b84a 100644 --- a/frontend/src/core/i18n/locales/en-US.ts +++ b/frontend/src/core/i18n/locales/en-US.ts @@ -96,6 +96,12 @@ export const enUS: Translations = { searchModels: "Search models...", surpriseMe: "Surprise", surpriseMePrompt: "Surprise me", + followupLoading: "Generating follow-up questions...", + followupConfirmTitle: "Send suggestion?", + followupConfirmDescription: + "You already have text in the input. Choose how to send it.", + followupConfirmAppend: "Append & send", + followupConfirmReplace: "Replace & send", suggestions: [ { suggestion: "Write", diff --git a/frontend/src/core/i18n/locales/types.ts b/frontend/src/core/i18n/locales/types.ts index 06bb403..b385d45 100644 --- a/frontend/src/core/i18n/locales/types.ts +++ b/frontend/src/core/i18n/locales/types.ts @@ -76,6 +76,11 @@ export interface Translations { searchModels: string; surpriseMe: string; surpriseMePrompt: string; + followupLoading: string; + followupConfirmTitle: string; + followupConfirmDescription: string; + followupConfirmAppend: string; + followupConfirmReplace: string; suggestions: { suggestion: string; prompt: string; diff --git a/frontend/src/core/i18n/locales/zh-CN.ts b/frontend/src/core/i18n/locales/zh-CN.ts index 693ea82..d6d8031 100644 --- a/frontend/src/core/i18n/locales/zh-CN.ts +++ b/frontend/src/core/i18n/locales/zh-CN.ts @@ -92,6 +92,11 @@ export const zhCN: Translations = { searchModels: "搜索模型...", surpriseMe: "小惊喜", surpriseMePrompt: "给我一个小惊喜吧", + followupLoading: "正在生成可能的后续问题...", + followupConfirmTitle: "发送建议问题?", + followupConfirmDescription: "当前输入框已有内容,选择发送方式。", + followupConfirmAppend: "追加并发送", + followupConfirmReplace: "替换并发送", suggestions: [ { suggestion: "写作",