diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 1f8f55f..7c4a625 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -101,10 +101,7 @@ def planner_node( } ] - if ( - AGENT_LLM_MAP["planner"] == "basic" - and not configurable.enable_deep_thinking - ): + if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking: llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output( Plan, method="json_mode", @@ -117,10 +114,7 @@ def planner_node( return Command(goto="reporter") full_response = "" - if ( - AGENT_LLM_MAP["planner"] == "basic" - and not configurable.enable_deep_thinking - ): + if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking: response = llm.invoke(messages) full_response = response.model_dump_json(indent=4, exclude_none=True) else: diff --git a/src/llms/llm.py b/src/llms/llm.py index 9777010..62fff28 100644 --- a/src/llms/llm.py +++ b/src/llms/llm.py @@ -7,6 +7,7 @@ import os from langchain_openai import ChatOpenAI from langchain_deepseek import ChatDeepSeek +from typing import get_args from src.config import load_yaml_config from src.config.agents import LLMType @@ -15,6 +16,20 @@ from src.config.agents import LLMType _llm_cache: dict[LLMType, ChatOpenAI] = {} +def _get_config_file_path() -> str: + """Get the path to the configuration file.""" + return str((Path(__file__).parent.parent.parent / "conf.yaml").resolve()) + + +def _get_llm_type_config_keys() -> dict[str, str]: + """Get mapping of LLM types to their configuration keys.""" + return { + "reasoning": "REASONING_MODEL", + "basic": "BASIC_MODEL", + "vision": "VISION_MODEL", + } + + def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]: """ Get LLM configuration from environment variables. @@ -33,14 +48,17 @@ def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]: def _create_llm_use_conf( llm_type: LLMType, conf: Dict[str, Any] ) -> ChatOpenAI | ChatDeepSeek: - llm_type_map = { - "reasoning": conf.get("REASONING_MODEL", {}), - "basic": conf.get("BASIC_MODEL", {}), - "vision": conf.get("VISION_MODEL", {}), - } - llm_conf = llm_type_map.get(llm_type) + """Create LLM instance using configuration.""" + llm_type_config_keys = _get_llm_type_config_keys() + config_key = llm_type_config_keys.get(llm_type) + + if not config_key: + raise ValueError(f"Unknown LLM type: {llm_type}") + + llm_conf = conf.get(config_key, {}) if not isinstance(llm_conf, dict): - raise ValueError(f"Invalid LLM Conf: {llm_type}") + raise ValueError(f"Invalid LLM configuration for {llm_type}: {llm_conf}") + # Get configuration from environment variables env_conf = _get_env_llm_conf(llm_type) @@ -48,10 +66,10 @@ def _create_llm_use_conf( merged_conf = {**llm_conf, **env_conf} if not merged_conf: - raise ValueError(f"Unknown LLM Conf: {llm_type}") + raise ValueError(f"No configuration found for LLM type: {llm_type}") if llm_type == "reasoning": - merged_conf["api_base"] = merged_conf.pop("base_url") + merged_conf["api_base"] = merged_conf.pop("base_url", None) return ( ChatOpenAI(**merged_conf) @@ -69,14 +87,49 @@ def get_llm_by_type( if llm_type in _llm_cache: return _llm_cache[llm_type] - conf = load_yaml_config( - str((Path(__file__).parent.parent.parent / "conf.yaml").resolve()) - ) + conf = load_yaml_config(_get_config_file_path()) llm = _create_llm_use_conf(llm_type, conf) _llm_cache[llm_type] = llm return llm +def get_configured_llms() -> dict[str, list[str]]: + """ + Get all configured LLM models grouped by type. + + Returns: + Dictionary mapping LLM type to list of configured model names. + """ + try: + conf = load_yaml_config(_get_config_file_path()) + llm_type_config_keys = _get_llm_type_config_keys() + + configured_llms: dict[str, list[str]] = {} + + for llm_type in get_args(LLMType): + # Get configuration from YAML file + config_key = llm_type_config_keys.get(llm_type, "") + yaml_conf = conf.get(config_key, {}) if config_key else {} + + # Get configuration from environment variables + env_conf = _get_env_llm_conf(llm_type) + + # Merge configurations, with environment variables taking precedence + merged_conf = {**yaml_conf, **env_conf} + + # Check if model is configured + model_name = merged_conf.get("model") + if model_name: + configured_llms.setdefault(llm_type, []).append(model_name) + + return configured_llms + + except Exception as e: + # Log error and return empty dict to avoid breaking the application + print(f"Warning: Failed to load LLM configuration: {e}") + return {} + + # In the future, we will use reasoning_llm and vl_llm for different purposes # reasoning_llm = get_llm_by_type("reasoning") # vl_llm = get_llm_by_type("vision") diff --git a/src/server/app.py b/src/server/app.py index aac13a2..7fd99da 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -38,6 +38,8 @@ from src.server.rag_request import ( RAGResourceRequest, RAGResourcesResponse, ) +from src.server.config_request import ConfigResponse +from src.llms.llm import get_configured_llms from src.tools import VolcengineTTS logger = logging.getLogger(__name__) @@ -405,3 +407,12 @@ async def rag_resources(request: Annotated[RAGResourceRequest, Query()]): if retriever: return RAGResourcesResponse(resources=retriever.list_resources(request.query)) return RAGResourcesResponse(resources=[]) + + +@app.get("/api/config", response_model=ConfigResponse) +async def config(): + """Get the config of the server.""" + return ConfigResponse( + rag_config=RAGConfigResponse(provider=SELECTED_RAG_PROVIDER), + configured_llms=get_configured_llms(), + ) diff --git a/src/server/config_request.py b/src/server/config_request.py new file mode 100644 index 0000000..7203f51 --- /dev/null +++ b/src/server/config_request.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +from pydantic import BaseModel, Field + +from src.server.rag_request import RAGConfigResponse + + +class ConfigResponse(BaseModel): + """Response model for server config.""" + + rag_config: RAGConfigResponse = Field(..., description="The config of the RAG") + configured_llms: dict[str, list[str]] = Field(..., description="The configured LLM")