diff --git a/src/config/agents.py b/src/config/agents.py
index 4454c56..e88a057 100644
--- a/src/config/agents.py
+++ b/src/config/agents.py
@@ -15,4 +15,5 @@ AGENT_LLM_MAP: dict[str, LLMType] = {
     "reporter": "basic",
     "podcast_script_writer": "basic",
     "ppt_composer": "basic",
+    "prose_writer": "basic",
 }
diff --git a/src/prose/graph/builder.py b/src/prose/graph/builder.py
new file mode 100644
index 0000000..d9bad5d
--- /dev/null
+++ b/src/prose/graph/builder.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+import asyncio
+import logging
+from langgraph.graph import END, START, StateGraph
+
+from src.prose.graph.prose_continue_node import prose_continue_node
+from src.prose.graph.prose_fix_node import prose_fix_node
+from src.prose.graph.prose_improve_node import prose_improve_node
+from src.prose.graph.prose_longer_node import prose_longer_node
+from src.prose.graph.prose_shorter_node import prose_shorter_node
+from src.prose.graph.prose_zap_node import prose_zap_node
+from src.prose.graph.state import ProseState
+
+
+def optional_node(state: ProseState):
+    return state["option"]
+
+
+def build_graph():
+    """Build and return the prose workflow graph."""
+    # build state graph
+    builder = StateGraph(ProseState)
+    builder.add_node("prose_continue", prose_continue_node)
+    builder.add_node("prose_improve", prose_improve_node)
+    builder.add_node("prose_shorter", prose_shorter_node)
+    builder.add_node("prose_longer", prose_longer_node)
+    builder.add_node("prose_fix", prose_fix_node)
+    builder.add_node("prose_zap", prose_zap_node)
+    builder.add_conditional_edges(
+        START,
+        optional_node,
+        {
+            "continue": "prose_continue",
+            "improve": "prose_improve",
+            "shorter": "prose_shorter",
+            "longer": "prose_longer",
+            "fix": "prose_fix",
+            "zap": "prose_zap",
+        },
+        END,
+    )
+    return builder.compile()
+
+
+async def _test_workflow():
+    workflow = build_graph()
+    events = workflow.astream(
+        {
+            "content": "The weather in Beijing is sunny",
+            "option": "continue",
+        },
+        stream_mode="messages",
+        subgraphs=True,
+    )
+    async for node, event in events:
+        e = event[0]
+        print({"id": e.id, "object": "chat.completion.chunk", "content": e.content})
+
+
+if __name__ == "__main__":
+    from dotenv import load_dotenv
+
+    load_dotenv()
+    logging.basicConfig(level=logging.INFO)
+    asyncio.run(_test_workflow())
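Because the conditional edge dispatches on `state["option"]` straight from START, each request runs exactly one prose node and then terminates. Here is a minimal usage sketch of the compiled graph (an annotation, not part of the patch; it assumes the repo's LLM configuration is loaded, as in `_test_workflow` above):

```python
# Annotation (not part of the patch): a minimal sketch of driving the
# compiled graph synchronously. Assumes the repository's LLM configuration
# is available (e.g. via load_dotenv), as in _test_workflow.
from src.prose.graph.builder import build_graph

workflow = build_graph()
result = workflow.invoke(
    {
        "content": "The weather in Beijing is sunny",
        # Any of: continue, improve, shorter, longer, fix, zap.
        "option": "improve",
    }
)
print(result["output"])  # the rewritten text from the selected node
```

Routing at START keeps each rewrite mode an isolated single-node path, so adding a new mode is one node plus one mapping entry.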
diff --git a/src/prose/graph/prose_continue_node.py b/src/prose/graph/prose_continue_node.py
new file mode 100644
index 0000000..0c05f5c
--- /dev/null
+++ b/src/prose/graph/prose_continue_node.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+import logging
+
+from langchain.schema import HumanMessage, SystemMessage
+
+from src.config.agents import AGENT_LLM_MAP
+from src.llms.llm import get_llm_by_type
+from src.prose.graph.state import ProseState
+
+logger = logging.getLogger(__name__)
+
+
+def prose_continue_node(state: ProseState):
+    logger.info("Generating prose continue content...")
+    model = get_llm_by_type(AGENT_LLM_MAP["prose_writer"])
+    prose_content = model.invoke(
+        [
+            SystemMessage(
+                content="""
+You are an AI writing assistant that continues existing text based on context from prior text.
+- Give more weight/priority to the later characters than the beginning ones.
+- Limit your response to no more than 200 characters, but make sure to construct complete sentences.
+- Use Markdown formatting when appropriate.
+"""
+            ),
+            HumanMessage(content=state["content"]),
+        ],
+    )
+    return {"output": prose_content.content}
diff --git a/src/prose/graph/prose_fix_node.py b/src/prose/graph/prose_fix_node.py
new file mode 100644
index 0000000..000a83d
--- /dev/null
+++ b/src/prose/graph/prose_fix_node.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+import logging
+
+from langchain.schema import HumanMessage, SystemMessage
+
+from src.config.agents import AGENT_LLM_MAP
+from src.llms.llm import get_llm_by_type
+from src.prose.graph.state import ProseState
+
+logger = logging.getLogger(__name__)
+
+
+def prose_fix_node(state: ProseState):
+    logger.info("Generating prose fix content...")
+    model = get_llm_by_type(AGENT_LLM_MAP["prose_writer"])
+    prose_content = model.invoke(
+        [
+            SystemMessage(
+                content="""
+You are an AI writing assistant that fixes grammar and spelling errors in existing text.
+- Limit your response to no more than 200 characters, but make sure to construct complete sentences.
+- Use Markdown formatting when appropriate.
+- If the text is already correct, just return the original text.
+"""
+            ),
+            HumanMessage(content=f"The existing text is: {state['content']}"),
+        ],
+    )
+    logger.info(f"prose_content: {prose_content}")
+    return {"output": prose_content.content}
diff --git a/src/prose/graph/prose_improve_node.py b/src/prose/graph/prose_improve_node.py
new file mode 100644
index 0000000..1d9d901
--- /dev/null
+++ b/src/prose/graph/prose_improve_node.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+import logging
+
+from langchain.schema import HumanMessage, SystemMessage
+
+from src.config.agents import AGENT_LLM_MAP
+from src.llms.llm import get_llm_by_type
+from src.prose.graph.state import ProseState
+
+logger = logging.getLogger(__name__)
+
+
+def prose_improve_node(state: ProseState):
+    logger.info("Generating prose improve content...")
+    model = get_llm_by_type(AGENT_LLM_MAP["prose_writer"])
+    prose_content = model.invoke(
+        [
+            SystemMessage(
+                content="""
+You are an AI writing assistant that improves existing text.
+- Limit your response to no more than 200 characters, but make sure to construct complete sentences.
+- Use Markdown formatting when appropriate.
+"""
+            ),
+            HumanMessage(content=f"The existing text is: {state['content']}"),
+        ],
+    )
+    logger.info(f"prose_content: {prose_content}")
+    return {"output": prose_content.content}
diff --git a/src/prose/graph/prose_longer_node.py b/src/prose/graph/prose_longer_node.py
new file mode 100644
index 0000000..958cc48
--- /dev/null
+++ b/src/prose/graph/prose_longer_node.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+import logging
+
+from langchain.schema import HumanMessage, SystemMessage
+
+from src.config.agents import AGENT_LLM_MAP
+from src.llms.llm import get_llm_by_type
+from src.prose.graph.state import ProseState
+
+logger = logging.getLogger(__name__)
+
+
+def prose_longer_node(state: ProseState):
+    logger.info("Generating prose longer content...")
+    model = get_llm_by_type(AGENT_LLM_MAP["prose_writer"])
+    prose_content = model.invoke(
+        [
+            SystemMessage(
+                content="""
+You are an AI writing assistant that lengthens existing text.
+- Use Markdown formatting when appropriate.
+"""
+            ),
+            HumanMessage(content=f"The existing text is: {state['content']}"),
+        ],
+    )
+    logger.info(f"prose_content: {prose_content}")
+    return {"output": prose_content.content}
+""" + ), + HumanMessage(content=f"The existing text is: {state['content']}"), + ], + ) + logger.info(f"prose_content: {prose_content}") + return {"output": prose_content.content} diff --git a/src/prose/graph/prose_shorter_node.py b/src/prose/graph/prose_shorter_node.py new file mode 100644 index 0000000..3dbe26f --- /dev/null +++ b/src/prose/graph/prose_shorter_node.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import logging + +from langchain.schema import HumanMessage, SystemMessage + +from src.config.agents import AGENT_LLM_MAP +from src.llms.llm import get_llm_by_type +from src.prose.graph.state import ProseState + +logger = logging.getLogger(__name__) + + +def prose_shorter_node(state: ProseState): + logger.info("Generating prose shorter content...") + model = get_llm_by_type(AGENT_LLM_MAP["prose_writer"]) + prose_content = model.invoke( + [ + SystemMessage( + content=""" +You are an AI writing assistant that shortens existing text. +- Use Markdown formatting when appropriate. +""" + ), + HumanMessage(content=f"The existing text is: {state['content']}"), + ], + ) + logger.info(f"prose_content: {prose_content}") + return {"output": prose_content.content} diff --git a/src/prose/graph/prose_zap_node.py b/src/prose/graph/prose_zap_node.py new file mode 100644 index 0000000..1439f20 --- /dev/null +++ b/src/prose/graph/prose_zap_node.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import logging + +from langchain.schema import HumanMessage, SystemMessage + +from src.config.agents import AGENT_LLM_MAP +from src.llms.llm import get_llm_by_type +from src.prose.graph.state import ProseState + +logger = logging.getLogger(__name__) + + +def prose_zap_node(state: ProseState): + logger.info("Generating prose zap content...") + model = get_llm_by_type(AGENT_LLM_MAP["prose_writer"]) + prose_content = model.invoke( + [ + SystemMessage( + content=""" +You area an AI writing assistant that generates text based on a prompt. +- You take an input from the user and a command for manipulating the text." +- Use Markdown formatting when appropriate. +""" + ), + HumanMessage( + content=f"For this text: {state['content']}.\nYou have to respect the command: {state['command']}" + ), + ], + ) + logger.info(f"prose_content: {prose_content}") + return {"output": prose_content.content} diff --git a/src/prose/graph/state.py b/src/prose/graph/state.py new file mode 100644 index 0000000..fd4d92c --- /dev/null +++ b/src/prose/graph/state.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025 Bytedance Ltd. 
diff --git a/src/server/app.py b/src/server/app.py
index c910dac..07dd266 100644
--- a/src/server/app.py
+++ b/src/server/app.py
@@ -17,11 +17,13 @@ from langgraph.types import Command
 from src.graph.builder import build_graph_with_memory
 from src.podcast.graph.builder import build_graph as build_podcast_graph
 from src.ppt.graph.builder import build_graph as build_ppt_graph
+from src.prose.graph.builder import build_graph as build_prose_graph
 from src.server.chat_request import (
     ChatMessage,
     ChatRequest,
     GeneratePodcastRequest,
     GeneratePPTRequest,
+    GenerateProseRequest,
     TTSRequest,
 )
 from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
@@ -254,6 +256,29 @@ async def generate_ppt(request: GeneratePPTRequest):
         raise HTTPException(status_code=500, detail=str(e))
 
 
+@app.post("/api/prose/generate")
+async def generate_prose(request: GenerateProseRequest):
+    try:
+        logger.info(f"Generating prose for prompt: {request.prompt}")
+        workflow = build_prose_graph()
+        events = workflow.astream(
+            {
+                "content": request.prompt,
+                "option": request.option,
+                "command": request.command,
+            },
+            stream_mode="messages",
+            subgraphs=True,
+        )
+        return StreamingResponse(
+            (f"data: {event[0].content}\n\n" async for _, event in events),
+            media_type="text/event-stream",
+        )
+    except Exception as e:
+        logger.exception(f"Error occurred during prose generation: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
 async def mcp_server_metadata(request: MCPServerMetadataRequest):
     """Get information about an MCP server."""
diff --git a/src/server/chat_request.py b/src/server/chat_request.py
index e2abcde..6626fe1 100644
--- a/src/server/chat_request.py
+++ b/src/server/chat_request.py
@@ -74,3 +74,11 @@ class GeneratePodcastRequest(BaseModel):
 
 class GeneratePPTRequest(BaseModel):
     content: str = Field(..., description="The content of the ppt")
+
+
+class GenerateProseRequest(BaseModel):
+    prompt: str = Field(..., description="The content of the prose")
+    option: str = Field(..., description="The option of the prose writer")
+    command: Optional[str] = Field(
+        "", description="The user custom command of the prose writer"
+    )
diff --git a/web/package.json b/web/package.json
index c976280..261f232 100644
--- a/web/package.json
+++ b/web/package.json
@@ -17,7 +17,6 @@
     "typecheck": "tsc --noEmit"
   },
   "dependencies": {
-    "@ai-sdk/react": "^1.2.9",
     "@ant-design/icons": "^6.0.0",
     "@hookform/resolvers": "^5.0.1",
     "@nanostores/react": "github:ai/react",
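For reference, a hedged client sketch for the new endpoint (not part of the patch; the localhost URL and the `httpx` dependency are assumptions for illustration). Each SSE frame emitted by `generate_prose` is a bare text chunk behind a `data: ` prefix:

```python
# Illustrative client for POST /api/prose/generate; assumes a local server
# on port 8000 and httpx installed. Prints chunks as they stream in.
import httpx

payload = {
    "prompt": "The weather in Beijing is sunny",
    "option": "continue",
    "command": "",
}
with httpx.stream(
    "POST",
    "http://localhost:8000/api/prose/generate",  # assumed host/port
    json=payload,
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)
```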
diff --git a/web/src/components/editor/generative/ai-selector.tsx b/web/src/components/editor/generative/ai-selector.tsx
index b0f299d..0c676bf 100644
--- a/web/src/components/editor/generative/ai-selector.tsx
+++ b/web/src/components/editor/generative/ai-selector.tsx
@@ -2,11 +2,10 @@
 
 import { Command, CommandInput } from "../../ui/command";
 
-import { useCompletion } from "@ai-sdk/react";
 import { ArrowUp } from "lucide-react";
 import { useEditor } from "novel";
 import { addAIHighlight } from "novel";
-import { useState } from "react";
+import { useCallback, useState } from "react";
 import Markdown from "react-markdown";
 import { toast } from "sonner";
 import { Button } from "../../ui/button";
@@ -15,6 +14,8 @@ import { ScrollArea } from "../../ui/scroll-area";
 import AICompletionCommands from "./ai-completion-command";
 import AISelectorCommands from "./ai-selector-commands";
 import { LoadingOutlined } from "@ant-design/icons";
+import { resolveServiceURL } from "~/core/api/resolve-service-url";
+import { fetchStream } from "~/core/sse";
 
 //TODO: I think it makes more sense to create a custom Tiptap extension for this functionality https://tiptap.dev/docs/editor/ai/introduction
 interface AISelectorProps {
@@ -22,23 +23,72 @@ interface AISelectorProps {
   onOpenChange: (open: boolean) => void;
 }
 
+function useProseCompletion() {
+  const [completion, setCompletion] = useState("");
+  const [isLoading, setIsLoading] = useState(false);
+  const [error, setError] = useState<Error | null>(null);
+
+  const complete = useCallback(
+    async (prompt: string, options?: { body?: Record<string, unknown> }) => {
+      setIsLoading(true);
+      setError(null);
+
+      try {
+        const response = await fetchStream(
+          resolveServiceURL("/api/prose/generate"),
+          {
+            method: "POST",
+            headers: {
+              "Content-Type": "application/json",
+            },
+            body: JSON.stringify({
+              prompt,
+              ...options?.body,
+            }),
+          },
+        );
+
+        let fullText = "";
+
+        // Process the streaming response
+        for await (const chunk of response) {
+          fullText += chunk.data;
+          setCompletion(fullText);
+        }
+
+        setIsLoading(false);
+        return fullText;
+      } catch (e) {
+        const error = e instanceof Error ? e : new Error("An error occurred");
+        setError(error);
+        toast.error(error.message);
+        setIsLoading(false);
+        throw error;
+      }
+    },
+    [],
+  );
+
+  const reset = useCallback(() => {
+    setCompletion("");
+    setError(null);
+    setIsLoading(false);
+  }, []);
+
+  return {
+    completion,
+    complete,
+    isLoading,
+    error,
+    reset,
+  };
+}
+
 export function AISelector({ onOpenChange }: AISelectorProps) {
   const { editor } = useEditor();
   const [inputValue, setInputValue] = useState("");
-  const { completion, complete, isLoading } = useCompletion({
-    // id: "novel",
-    api: "/api/generate",
-    onResponse: (response) => {
-      if (response.status === 429) {
-        toast.error("You have reached your request limit for the day.");
-        return;
-      }
-    },
-    onError: (e) => {
-      toast.error(e.message);
-    },
-  });
+  const { completion, complete, isLoading } = useProseCompletion();
 
   if (!editor) return null;
 
   const hasCompletion = completion.length > 0;
@@ -57,7 +107,7 @@ export function AISelector({ onOpenChange }: AISelectorProps) {
       )}
       {isLoading && (
-        <div className="flex h-12 w-full items-center px-4 text-sm font-medium text-muted-foreground text-purple-500">
+        <div className="text-muted-foreground flex h-12 w-full items-center px-4 text-sm font-medium text-purple-500">
           AI is thinking
diff --git a/web/src/core/api/chat.ts b/web/src/core/api/chat.ts
index 6f500cd..6e04483 100644
--- a/web/src/core/api/chat.ts
+++ b/web/src/core/api/chat.ts
@@ -9,7 +9,7 @@ import { sleep } from "../utils";
 import { resolveServiceURL } from "./resolve-service-url";
 import type { ChatEvent } from "./types";
 
-export function chatStream(
+export async function* chatStream(
   userMessage: string,
   params: {
     thread_id: string;
@@ -32,13 +32,19 @@
   if (location.search.includes("mock") || location.search.includes("replay=")) {
-    return chatReplayStream(userMessage, params, options);
+    return yield* chatReplayStream(userMessage, params, options);
   }
-  return fetchStream(resolveServiceURL("chat/stream"), {
+  const stream = fetchStream(resolveServiceURL("chat/stream"), {
     body: JSON.stringify({
       messages: [{ role: "user", content: userMessage }],
       ...params,
    }),
     signal: options.abortSignal,
   });
+  for await (const event of stream) {
+    yield {
+      type: event.event,
+      data: JSON.parse(event.data),
+    } as ChatEvent;
+  }
 }
 
 async function* chatReplayStream(
diff --git a/web/src/core/sse/StreamEvent.ts b/web/src/core/sse/StreamEvent.ts
index fb91672..4d1e2e3 100644
--- a/web/src/core/sse/StreamEvent.ts
+++ b/web/src/core/sse/StreamEvent.ts
@@ -2,6 +2,6 @@
 // SPDX-License-Identifier: MIT
 
 export interface StreamEvent {
-  type: string;
-  data: object;
+  event: string;
+  data: string;
 }
diff --git a/web/src/core/sse/fetch-stream.ts b/web/src/core/sse/fetch-stream.ts
index e19ea67..7fd4321 100644
--- a/web/src/core/sse/fetch-stream.ts
+++ b/web/src/core/sse/fetch-stream.ts
@@ -3,10 +3,10 @@
 
 import { type StreamEvent } from "./StreamEvent";
 
-export async function* fetchStream<T extends StreamEvent>(
+export async function* fetchStream(
   url: string,
   init: RequestInit,
-): AsyncIterable<T> {
+): AsyncIterable<StreamEvent> {
   const response = await fetch(url, {
     method: "POST",
     headers: {
@@ -39,7 +39,7 @@
     }
     const chunk = buffer.slice(0, index);
     buffer = buffer.slice(index + 2);
-    const event = parseEvent<T>(chunk);
+    const event = parseEvent(chunk);
     if (event) {
       yield event;
     }
@@ -47,9 +47,9 @@
   }
 }
 
-function parseEvent<T extends StreamEvent>(chunk: string) {
-  let resultType = "message";
-  let resultData: object | null = null;
+function parseEvent(chunk: string) {
+  let resultEvent = "message";
+  let resultData: string | null = null;
   for (const line of chunk.split("\n")) {
     const pos = line.indexOf(": ");
     if (pos === -1) {
@@ -58,16 +58,16 @@
     const key = line.slice(0, pos);
     const value = line.slice(pos + 2);
     if (key === "event") {
-      resultType = value;
+      resultEvent = value;
     } else if (key === "data") {
-      resultData = JSON.parse(value);
+      resultData = value;
     }
   }
-  if (resultType === "message" && resultData === null) {
+  if (resultEvent === "message" && resultData === null) {
     return undefined;
   }
   return {
-    type: resultType,
+    event: resultEvent,
     data: resultData,
-  } as T;
+  } as StreamEvent;
 }
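To pin down the wire format the reworked `parseEvent` now produces, here is an equivalent parser sketched in Python (illustrative only, not part of the patch): `data:` payloads stay raw strings, and JSON decoding moves up to callers such as `chatStream`:

```python
# Illustrative Python mirror of parseEvent: split one SSE frame into
# (event, data), keeping data as a raw string.
def parse_event(chunk: str):
    event, data = "message", None
    for line in chunk.split("\n"):
        key, sep, value = line.partition(": ")
        if not sep:  # skip lines without a "key: value" shape
            continue
        if key == "event":
            event = value
        elif key == "data":
            data = value
    if event == "message" and data is None:
        return None
    return {"event": event, "data": data}


frame = 'event: message_chunk\ndata: {"content": "Hello"}'
assert parse_event(frame) == {"event": "message_chunk", "data": '{"content": "Hello"}'}
```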