mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-09 16:54:46 +08:00
feat: add deep think feature (#311)
* feat: implement backend logic * feat: implement api/config endpoint * rename the symbol * feat: re-implement configuration at client-side * feat: add client-side of deep thinking * fix backend bug * feat: add reasoning block * docs: update readme * fix: translate into English * fix: change icon to lightbulb * feat: ignore more bad cases * feat: adjust thinking layout, and implement auto scrolling * docs: add comments --------- Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
@@ -23,6 +23,7 @@ class Configuration:
|
||||
max_search_results: int = 3 # Maximum number of search results
|
||||
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
|
||||
report_style: str = ReportStyle.ACADEMIC.value # Report style
|
||||
enable_deep_thinking: bool = False # Whether to enable deep thinking
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
|
||||
@@ -101,8 +101,10 @@ def planner_node(
|
||||
}
|
||||
]
|
||||
|
||||
if AGENT_LLM_MAP["planner"] == "basic":
|
||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output(
|
||||
if configurable.enable_deep_thinking:
|
||||
llm = get_llm_by_type("reasoning")
|
||||
elif AGENT_LLM_MAP["planner"] == "basic":
|
||||
llm = get_llm_by_type("basic").with_structured_output(
|
||||
Plan,
|
||||
method="json_schema",
|
||||
strict=True,
|
||||
@@ -115,7 +117,7 @@ def planner_node(
|
||||
return Command(goto="reporter")
|
||||
|
||||
full_response = ""
|
||||
if AGENT_LLM_MAP["planner"] == "basic":
|
||||
if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking:
|
||||
response = llm.invoke(messages)
|
||||
full_response = response.model_dump_json(indent=4, exclude_none=True)
|
||||
else:
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any, Dict
|
||||
import os
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from typing import get_args
|
||||
|
||||
from src.config import load_yaml_config
|
||||
from src.config.agents import LLMType
|
||||
@@ -14,6 +16,20 @@ from src.config.agents import LLMType
|
||||
_llm_cache: dict[LLMType, ChatOpenAI] = {}
|
||||
|
||||
|
||||
def _get_config_file_path() -> str:
|
||||
"""Get the path to the configuration file."""
|
||||
return str((Path(__file__).parent.parent.parent / "conf.yaml").resolve())
|
||||
|
||||
|
||||
def _get_llm_type_config_keys() -> dict[str, str]:
|
||||
"""Get mapping of LLM types to their configuration keys."""
|
||||
return {
|
||||
"reasoning": "REASONING_MODEL",
|
||||
"basic": "BASIC_MODEL",
|
||||
"vision": "VISION_MODEL",
|
||||
}
|
||||
|
||||
|
||||
def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get LLM configuration from environment variables.
|
||||
@@ -29,15 +45,20 @@ def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
|
||||
return conf
|
||||
|
||||
|
||||
def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> ChatOpenAI:
|
||||
llm_type_map = {
|
||||
"reasoning": conf.get("REASONING_MODEL", {}),
|
||||
"basic": conf.get("BASIC_MODEL", {}),
|
||||
"vision": conf.get("VISION_MODEL", {}),
|
||||
}
|
||||
llm_conf = llm_type_map.get(llm_type)
|
||||
def _create_llm_use_conf(
|
||||
llm_type: LLMType, conf: Dict[str, Any]
|
||||
) -> ChatOpenAI | ChatDeepSeek:
|
||||
"""Create LLM instance using configuration."""
|
||||
llm_type_config_keys = _get_llm_type_config_keys()
|
||||
config_key = llm_type_config_keys.get(llm_type)
|
||||
|
||||
if not config_key:
|
||||
raise ValueError(f"Unknown LLM type: {llm_type}")
|
||||
|
||||
llm_conf = conf.get(config_key, {})
|
||||
if not isinstance(llm_conf, dict):
|
||||
raise ValueError(f"Invalid LLM Conf: {llm_type}")
|
||||
raise ValueError(f"Invalid LLM configuration for {llm_type}: {llm_conf}")
|
||||
|
||||
# Get configuration from environment variables
|
||||
env_conf = _get_env_llm_conf(llm_type)
|
||||
|
||||
@@ -45,9 +66,16 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> ChatOpenAI:
|
||||
merged_conf = {**llm_conf, **env_conf}
|
||||
|
||||
if not merged_conf:
|
||||
raise ValueError(f"Unknown LLM Conf: {llm_type}")
|
||||
raise ValueError(f"No configuration found for LLM type: {llm_type}")
|
||||
|
||||
return ChatOpenAI(**merged_conf)
|
||||
if llm_type == "reasoning":
|
||||
merged_conf["api_base"] = merged_conf.pop("base_url", None)
|
||||
|
||||
return (
|
||||
ChatOpenAI(**merged_conf)
|
||||
if llm_type != "reasoning"
|
||||
else ChatDeepSeek(**merged_conf)
|
||||
)
|
||||
|
||||
|
||||
def get_llm_by_type(
|
||||
@@ -59,14 +87,49 @@ def get_llm_by_type(
|
||||
if llm_type in _llm_cache:
|
||||
return _llm_cache[llm_type]
|
||||
|
||||
conf = load_yaml_config(
|
||||
str((Path(__file__).parent.parent.parent / "conf.yaml").resolve())
|
||||
)
|
||||
conf = load_yaml_config(_get_config_file_path())
|
||||
llm = _create_llm_use_conf(llm_type, conf)
|
||||
_llm_cache[llm_type] = llm
|
||||
return llm
|
||||
|
||||
|
||||
def get_configured_llm_models() -> dict[str, list[str]]:
|
||||
"""
|
||||
Get all configured LLM models grouped by type.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping LLM type to list of configured model names.
|
||||
"""
|
||||
try:
|
||||
conf = load_yaml_config(_get_config_file_path())
|
||||
llm_type_config_keys = _get_llm_type_config_keys()
|
||||
|
||||
configured_models: dict[str, list[str]] = {}
|
||||
|
||||
for llm_type in get_args(LLMType):
|
||||
# Get configuration from YAML file
|
||||
config_key = llm_type_config_keys.get(llm_type, "")
|
||||
yaml_conf = conf.get(config_key, {}) if config_key else {}
|
||||
|
||||
# Get configuration from environment variables
|
||||
env_conf = _get_env_llm_conf(llm_type)
|
||||
|
||||
# Merge configurations, with environment variables taking precedence
|
||||
merged_conf = {**yaml_conf, **env_conf}
|
||||
|
||||
# Check if model is configured
|
||||
model_name = merged_conf.get("model")
|
||||
if model_name:
|
||||
configured_models.setdefault(llm_type, []).append(model_name)
|
||||
|
||||
return configured_models
|
||||
|
||||
except Exception as e:
|
||||
# Log error and return empty dict to avoid breaking the application
|
||||
print(f"Warning: Failed to load LLM configuration: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# In the future, we will use reasoning_llm and vl_llm for different purposes
|
||||
# reasoning_llm = get_llm_by_type("reasoning")
|
||||
# vl_llm = get_llm_by_type("vision")
|
||||
|
||||
@@ -24,7 +24,6 @@ from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhanc
|
||||
from src.rag.builder import build_retriever
|
||||
from src.rag.retriever import Resource
|
||||
from src.server.chat_request import (
|
||||
ChatMessage,
|
||||
ChatRequest,
|
||||
EnhancePromptRequest,
|
||||
GeneratePodcastRequest,
|
||||
@@ -39,6 +38,8 @@ from src.server.rag_request import (
|
||||
RAGResourceRequest,
|
||||
RAGResourcesResponse,
|
||||
)
|
||||
from src.server.config_request import ConfigResponse
|
||||
from src.llms.llm import get_configured_llm_models
|
||||
from src.tools import VolcengineTTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -81,6 +82,7 @@ async def chat_stream(request: ChatRequest):
|
||||
request.mcp_settings,
|
||||
request.enable_background_investigation,
|
||||
request.report_style,
|
||||
request.enable_deep_thinking,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
@@ -98,6 +100,7 @@ async def _astream_workflow_generator(
|
||||
mcp_settings: dict,
|
||||
enable_background_investigation: bool,
|
||||
report_style: ReportStyle,
|
||||
enable_deep_thinking: bool,
|
||||
):
|
||||
input_ = {
|
||||
"messages": messages,
|
||||
@@ -125,6 +128,7 @@ async def _astream_workflow_generator(
|
||||
"max_search_results": max_search_results,
|
||||
"mcp_settings": mcp_settings,
|
||||
"report_style": report_style.value,
|
||||
"enable_deep_thinking": enable_deep_thinking,
|
||||
},
|
||||
stream_mode=["messages", "updates"],
|
||||
subgraphs=True,
|
||||
@@ -156,6 +160,10 @@ async def _astream_workflow_generator(
|
||||
"role": "assistant",
|
||||
"content": message_chunk.content,
|
||||
}
|
||||
if message_chunk.additional_kwargs.get("reasoning_content"):
|
||||
event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
|
||||
"reasoning_content"
|
||||
]
|
||||
if message_chunk.response_metadata.get("finish_reason"):
|
||||
event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
|
||||
"finish_reason"
|
||||
@@ -399,3 +407,12 @@ async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
|
||||
if retriever:
|
||||
return RAGResourcesResponse(resources=retriever.list_resources(request.query))
|
||||
return RAGResourcesResponse(resources=[])
|
||||
|
||||
|
||||
@app.get("/api/config", response_model=ConfigResponse)
|
||||
async def config():
|
||||
"""Get the config of the server."""
|
||||
return ConfigResponse(
|
||||
rag=RAGConfigResponse(provider=SELECTED_RAG_PROVIDER),
|
||||
models=get_configured_llm_models(),
|
||||
)
|
||||
|
||||
@@ -62,6 +62,9 @@ class ChatRequest(BaseModel):
|
||||
report_style: Optional[ReportStyle] = Field(
|
||||
ReportStyle.ACADEMIC, description="The style of the report"
|
||||
)
|
||||
enable_deep_thinking: Optional[bool] = Field(
|
||||
False, description="Whether to enable deep thinking"
|
||||
)
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
|
||||
13
src/server/config_request.py
Normal file
13
src/server/config_request.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.server.rag_request import RAGConfigResponse
|
||||
|
||||
|
||||
class ConfigResponse(BaseModel):
|
||||
"""Response model for server config."""
|
||||
|
||||
rag: RAGConfigResponse = Field(..., description="The config of the RAG")
|
||||
models: dict[str, list[str]] = Field(..., description="The configured models")
|
||||
Reference in New Issue
Block a user