mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-12 18:14:46 +08:00
feat: implement enhance prompt (#294)
* feat: implement enhance prompt * add unit test * fix prompt * fix: fix eslint and compiling issues * feat: add border-beam animation * fix: fix importing issues --------- Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
@@ -20,11 +20,13 @@ from src.graph.builder import build_graph_with_memory
|
||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
||||
from src.prose.graph.builder import build_graph as build_prose_graph
|
||||
from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph
|
||||
from src.rag.builder import build_retriever
|
||||
from src.rag.retriever import Resource
|
||||
from src.server.chat_request import (
|
||||
ChatMessage,
|
||||
ChatRequest,
|
||||
EnhancePromptRequest,
|
||||
GeneratePodcastRequest,
|
||||
GeneratePPTRequest,
|
||||
GenerateProseRequest,
|
||||
@@ -300,6 +302,50 @@ async def generate_prose(request: GenerateProseRequest):
|
||||
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
||||
|
||||
|
||||
@app.post("/api/prompt/enhance")
|
||||
async def enhance_prompt(request: EnhancePromptRequest):
|
||||
try:
|
||||
sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
|
||||
logger.info(f"Enhancing prompt: {sanitized_prompt}")
|
||||
|
||||
# Convert string report_style to ReportStyle enum
|
||||
report_style = None
|
||||
if request.report_style:
|
||||
try:
|
||||
# Handle both uppercase and lowercase input
|
||||
style_mapping = {
|
||||
"ACADEMIC": ReportStyle.ACADEMIC,
|
||||
"POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE,
|
||||
"NEWS": ReportStyle.NEWS,
|
||||
"SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA,
|
||||
"academic": ReportStyle.ACADEMIC,
|
||||
"popular_science": ReportStyle.POPULAR_SCIENCE,
|
||||
"news": ReportStyle.NEWS,
|
||||
"social_media": ReportStyle.SOCIAL_MEDIA,
|
||||
}
|
||||
report_style = style_mapping.get(
|
||||
request.report_style, ReportStyle.ACADEMIC
|
||||
)
|
||||
except Exception:
|
||||
# If invalid style, default to ACADEMIC
|
||||
report_style = ReportStyle.ACADEMIC
|
||||
else:
|
||||
report_style = ReportStyle.ACADEMIC
|
||||
|
||||
workflow = build_prompt_enhancer_graph()
|
||||
final_state = workflow.invoke(
|
||||
{
|
||||
"prompt": request.prompt,
|
||||
"context": request.context,
|
||||
"report_style": report_style,
|
||||
}
|
||||
)
|
||||
return {"result": final_state["output"]}
|
||||
except Exception as e:
|
||||
logger.exception(f"Error occurred during prompt enhancement: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
||||
|
||||
|
||||
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
|
||||
async def mcp_server_metadata(request: MCPServerMetadataRequest):
|
||||
"""Get information about an MCP server."""
|
||||
|
||||
@@ -94,3 +94,13 @@ class GenerateProseRequest(BaseModel):
|
||||
command: Optional[str] = Field(
|
||||
"", description="The user custom command of the prose writer"
|
||||
)
|
||||
|
||||
|
||||
class EnhancePromptRequest(BaseModel):
|
||||
prompt: str = Field(..., description="The original prompt to enhance")
|
||||
context: Optional[str] = Field(
|
||||
"", description="Additional context about the intended use"
|
||||
)
|
||||
report_style: Optional[str] = Field(
|
||||
"academic", description="The style of the report"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user