refactor: simplify style mapping by using upper case only (#378)

* improve: add abort btn to abort the mcp add request.

* refactor: simplify style mapping by using upper case only

* format: execute uv run black --preview . to format python files.
This commit is contained in:
Abeautifulsnow
2025-07-04 08:27:20 +08:00
committed by GitHub
parent be893eae2b
commit 7ad11bf86c

View File

@@ -11,16 +11,17 @@ from uuid import uuid4
from fastapi import FastAPI, HTTPException, Query from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage, BaseMessage from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
from langgraph.types import Command from langgraph.types import Command
from src.config.report_style import ReportStyle from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory from src.graph.builder import build_graph_with_memory
from src.llms.llm import get_configured_llm_models
from src.podcast.graph.builder import build_graph as build_podcast_graph from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph from src.ppt.graph.builder import build_graph as build_ppt_graph
from src.prose.graph.builder import build_graph as build_prose_graph
from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph
from src.prose.graph.builder import build_graph as build_prose_graph
from src.rag.builder import build_retriever from src.rag.builder import build_retriever
from src.rag.retriever import Resource from src.rag.retriever import Resource
from src.server.chat_request import ( from src.server.chat_request import (
@@ -31,6 +32,7 @@ from src.server.chat_request import (
GenerateProseRequest, GenerateProseRequest,
TTSRequest, TTSRequest,
) )
from src.server.config_request import ConfigResponse
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
from src.server.mcp_utils import load_mcp_tools from src.server.mcp_utils import load_mcp_tools
from src.server.rag_request import ( from src.server.rag_request import (
@@ -38,8 +40,6 @@ from src.server.rag_request import (
RAGResourceRequest, RAGResourceRequest,
RAGResourcesResponse, RAGResourcesResponse,
) )
from src.server.config_request import ConfigResponse
from src.llms.llm import get_configured_llm_models
from src.tools import VolcengineTTS from src.tools import VolcengineTTS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -327,13 +327,9 @@ async def enhance_prompt(request: EnhancePromptRequest):
"POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE, "POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE,
"NEWS": ReportStyle.NEWS, "NEWS": ReportStyle.NEWS,
"SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA, "SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA,
"academic": ReportStyle.ACADEMIC,
"popular_science": ReportStyle.POPULAR_SCIENCE,
"news": ReportStyle.NEWS,
"social_media": ReportStyle.SOCIAL_MEDIA,
} }
report_style = style_mapping.get( report_style = style_mapping.get(
request.report_style, ReportStyle.ACADEMIC request.report_style.upper(), ReportStyle.ACADEMIC
) )
except Exception: except Exception:
# If invalid style, default to ACADEMIC # If invalid style, default to ACADEMIC