feat: add deep think feature (#311)

* feat: implement backend logic

* feat: implement api/config endpoint

* rename the symbol

* feat: re-implement configuration at client-side

* feat: add client-side of deep thinking

* fix backend bug

* feat: add reasoning block

* docs: update readme

* fix: translate into English

* fix: change icon to lightbulb

* feat: ignore more bad cases

* feat: adjust thinking layout, and implement auto scrolling

* docs: add comments

---------

Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
DanielWalnut
2025-06-14 13:12:43 +08:00
committed by GitHub
parent a7315b46df
commit 19fa1e97c3
40 changed files with 2292 additions and 1102 deletions

View File

@@ -23,6 +23,7 @@ class Configuration:
max_search_results: int = 3 # Maximum number of search results
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
report_style: str = ReportStyle.ACADEMIC.value # Report style
enable_deep_thinking: bool = False # Whether to enable deep thinking
@classmethod
def from_runnable_config(

View File

@@ -101,8 +101,10 @@ def planner_node(
}
]
if AGENT_LLM_MAP["planner"] == "basic":
llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output(
if configurable.enable_deep_thinking:
llm = get_llm_by_type("reasoning")
elif AGENT_LLM_MAP["planner"] == "basic":
llm = get_llm_by_type("basic").with_structured_output(
Plan,
method="json_schema",
strict=True,
@@ -115,7 +117,7 @@ def planner_node(
return Command(goto="reporter")
full_response = ""
if AGENT_LLM_MAP["planner"] == "basic":
if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking:
response = llm.invoke(messages)
full_response = response.model_dump_json(indent=4, exclude_none=True)
else:

View File

@@ -6,6 +6,8 @@ from typing import Any, Dict
import os
from langchain_openai import ChatOpenAI
from langchain_deepseek import ChatDeepSeek
from typing import get_args
from src.config import load_yaml_config
from src.config.agents import LLMType
@@ -14,6 +16,20 @@ from src.config.agents import LLMType
_llm_cache: dict[LLMType, ChatOpenAI] = {}
def _get_config_file_path() -> str:
"""Get the path to the configuration file."""
return str((Path(__file__).parent.parent.parent / "conf.yaml").resolve())
def _get_llm_type_config_keys() -> dict[str, str]:
"""Get mapping of LLM types to their configuration keys."""
return {
"reasoning": "REASONING_MODEL",
"basic": "BASIC_MODEL",
"vision": "VISION_MODEL",
}
def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
"""
Get LLM configuration from environment variables.
@@ -29,15 +45,20 @@ def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
return conf
def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> ChatOpenAI:
llm_type_map = {
"reasoning": conf.get("REASONING_MODEL", {}),
"basic": conf.get("BASIC_MODEL", {}),
"vision": conf.get("VISION_MODEL", {}),
}
llm_conf = llm_type_map.get(llm_type)
def _create_llm_use_conf(
llm_type: LLMType, conf: Dict[str, Any]
) -> ChatOpenAI | ChatDeepSeek:
"""Create LLM instance using configuration."""
llm_type_config_keys = _get_llm_type_config_keys()
config_key = llm_type_config_keys.get(llm_type)
if not config_key:
raise ValueError(f"Unknown LLM type: {llm_type}")
llm_conf = conf.get(config_key, {})
if not isinstance(llm_conf, dict):
raise ValueError(f"Invalid LLM Conf: {llm_type}")
raise ValueError(f"Invalid LLM configuration for {llm_type}: {llm_conf}")
# Get configuration from environment variables
env_conf = _get_env_llm_conf(llm_type)
@@ -45,9 +66,16 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> ChatOpenAI:
merged_conf = {**llm_conf, **env_conf}
if not merged_conf:
raise ValueError(f"Unknown LLM Conf: {llm_type}")
raise ValueError(f"No configuration found for LLM type: {llm_type}")
return ChatOpenAI(**merged_conf)
if llm_type == "reasoning":
merged_conf["api_base"] = merged_conf.pop("base_url", None)
return (
ChatOpenAI(**merged_conf)
if llm_type != "reasoning"
else ChatDeepSeek(**merged_conf)
)
def get_llm_by_type(
@@ -59,14 +87,49 @@ def get_llm_by_type(
if llm_type in _llm_cache:
return _llm_cache[llm_type]
conf = load_yaml_config(
str((Path(__file__).parent.parent.parent / "conf.yaml").resolve())
)
conf = load_yaml_config(_get_config_file_path())
llm = _create_llm_use_conf(llm_type, conf)
_llm_cache[llm_type] = llm
return llm
def get_configured_llm_models() -> dict[str, list[str]]:
"""
Get all configured LLM models grouped by type.
Returns:
Dictionary mapping LLM type to list of configured model names.
"""
try:
conf = load_yaml_config(_get_config_file_path())
llm_type_config_keys = _get_llm_type_config_keys()
configured_models: dict[str, list[str]] = {}
for llm_type in get_args(LLMType):
# Get configuration from YAML file
config_key = llm_type_config_keys.get(llm_type, "")
yaml_conf = conf.get(config_key, {}) if config_key else {}
# Get configuration from environment variables
env_conf = _get_env_llm_conf(llm_type)
# Merge configurations, with environment variables taking precedence
merged_conf = {**yaml_conf, **env_conf}
# Check if model is configured
model_name = merged_conf.get("model")
if model_name:
configured_models.setdefault(llm_type, []).append(model_name)
return configured_models
except Exception as e:
# Log error and return empty dict to avoid breaking the application
print(f"Warning: Failed to load LLM configuration: {e}")
return {}
# In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision")

View File

@@ -24,7 +24,6 @@ from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhanc
from src.rag.builder import build_retriever
from src.rag.retriever import Resource
from src.server.chat_request import (
ChatMessage,
ChatRequest,
EnhancePromptRequest,
GeneratePodcastRequest,
@@ -39,6 +38,8 @@ from src.server.rag_request import (
RAGResourceRequest,
RAGResourcesResponse,
)
from src.server.config_request import ConfigResponse
from src.llms.llm import get_configured_llm_models
from src.tools import VolcengineTTS
logger = logging.getLogger(__name__)
@@ -81,6 +82,7 @@ async def chat_stream(request: ChatRequest):
request.mcp_settings,
request.enable_background_investigation,
request.report_style,
request.enable_deep_thinking,
),
media_type="text/event-stream",
)
@@ -98,6 +100,7 @@ async def _astream_workflow_generator(
mcp_settings: dict,
enable_background_investigation: bool,
report_style: ReportStyle,
enable_deep_thinking: bool,
):
input_ = {
"messages": messages,
@@ -125,6 +128,7 @@ async def _astream_workflow_generator(
"max_search_results": max_search_results,
"mcp_settings": mcp_settings,
"report_style": report_style.value,
"enable_deep_thinking": enable_deep_thinking,
},
stream_mode=["messages", "updates"],
subgraphs=True,
@@ -156,6 +160,10 @@ async def _astream_workflow_generator(
"role": "assistant",
"content": message_chunk.content,
}
if message_chunk.additional_kwargs.get("reasoning_content"):
event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
"reasoning_content"
]
if message_chunk.response_metadata.get("finish_reason"):
event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
"finish_reason"
@@ -399,3 +407,12 @@ async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
if retriever:
return RAGResourcesResponse(resources=retriever.list_resources(request.query))
return RAGResourcesResponse(resources=[])
@app.get("/api/config", response_model=ConfigResponse)
async def config():
"""Get the config of the server."""
return ConfigResponse(
rag=RAGConfigResponse(provider=SELECTED_RAG_PROVIDER),
models=get_configured_llm_models(),
)

View File

@@ -62,6 +62,9 @@ class ChatRequest(BaseModel):
report_style: Optional[ReportStyle] = Field(
ReportStyle.ACADEMIC, description="The style of the report"
)
enable_deep_thinking: Optional[bool] = Field(
False, description="Whether to enable deep thinking"
)
class TTSRequest(BaseModel):

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from pydantic import BaseModel, Field
from src.server.rag_request import RAGConfigResponse
class ConfigResponse(BaseModel):
"""Response model for server config."""
rag: RAGConfigResponse = Field(..., description="The config of the RAG")
models: dict[str, list[str]] = Field(..., description="The configured models")