2025-04-17 11:34:42 +08:00
|
|
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
|
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
2025-05-07 17:23:25 +08:00
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any, Dict
|
2025-05-28 01:21:40 -07:00
|
|
|
import os
|
2025-05-07 17:23:25 +08:00
|
|
|
|
2025-04-07 16:25:55 +08:00
|
|
|
from langchain_openai import ChatOpenAI
|
2025-05-07 17:23:25 +08:00
|
|
|
|
2025-04-07 16:25:55 +08:00
|
|
|
from src.config import load_yaml_config
|
|
|
|
|
from src.config.agents import LLMType
|
|
|
|
|
|
|
|
|
|
# Cache for LLM instances
|
|
|
|
|
# Maps an LLM type to the ChatOpenAI instance built for it; get_llm_by_type
# stores each instance here on first request and returns it on later calls.
_llm_cache: dict[LLMType, ChatOpenAI] = {}
|
|
|
|
|
|
|
|
|
|
|
2025-05-28 01:21:40 -07:00
|
|
|
def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]:
    """
    Collect the LLM settings for one model type from environment variables.

    Variables are selected by the prefix ``{LLM_TYPE}_MODEL__`` (the type is
    upper-cased), e.g. ``BASIC_MODEL__api_key`` or ``BASIC_MODEL__base_url``
    for ``llm_type="basic"``. The prefix is stripped and the remainder is
    lower-cased to form the configuration key. Values are returned exactly as
    they appear in the environment (strings); an empty dict is returned when
    no variable matches.
    """
    env_prefix = f"{llm_type.upper()}_MODEL__"
    prefix_len = len(env_prefix)
    return {
        name[prefix_len:].lower(): raw_value
        for name, raw_value in os.environ.items()
        if name.startswith(env_prefix)
    }
|
|
|
|
|
|
|
|
|
|
|
2025-04-07 16:25:55 +08:00
|
|
|
def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> ChatOpenAI:
    """
    Build a ChatOpenAI instance for ``llm_type`` from file and environment settings.

    The section of ``conf`` matching the type (``REASONING_MODEL``,
    ``BASIC_MODEL`` or ``VISION_MODEL``) provides the base settings; values
    returned by ``_get_env_llm_conf`` are layered on top and win on conflicts.
    The merged mapping is passed to ``ChatOpenAI`` as keyword arguments.

    Raises:
        ValueError: "Invalid LLM Conf" when ``llm_type`` is not one of the
            known types or its section in ``conf`` is not a dict;
            "Unknown LLM Conf" when neither the file section nor the
            environment supplies any setting.
    """
    section_names = (
        ("reasoning", "REASONING_MODEL"),
        ("basic", "BASIC_MODEL"),
        ("vision", "VISION_MODEL"),
    )
    sections = {kind: conf.get(section, {}) for kind, section in section_names}

    file_conf = sections.get(llm_type)
    if not isinstance(file_conf, dict):
        raise ValueError(f"Invalid LLM Conf: {llm_type}")

    # Start from the file settings, then let environment variables override them.
    settings = dict(file_conf)
    settings.update(_get_env_llm_conf(llm_type))

    if not settings:
        raise ValueError(f"Unknown LLM Conf: {llm_type}")

    return ChatOpenAI(**settings)
|
2025-04-07 16:25:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_llm_by_type(llm_type: LLMType) -> ChatOpenAI:
    """
    Return the LLM instance for ``llm_type``, building it on first use.

    A previously created instance is served from ``_llm_cache``. Otherwise the
    settings are loaded from the ``conf.yaml`` found three ``parent`` steps up
    from this file, the instance is created via ``_create_llm_use_conf``,
    stored in the cache, and returned.
    """
    try:
        return _llm_cache[llm_type]
    except KeyError:
        pass

    base_dir = Path(__file__).parent.parent.parent
    config_path = (base_dir / "conf.yaml").resolve()
    yaml_conf = load_yaml_config(str(config_path))

    instance = _create_llm_use_conf(llm_type, yaml_conf)
    _llm_cache[llm_type] = instance
    return instance
|
|
|
|
|
|
|
|
|
|
|
2025-05-07 17:23:25 +08:00
|
|
|
# In the future, we will use reasoning_llm and vl_llm for different purposes
|
|
|
|
|
# reasoning_llm = get_llm_by_type("reasoning")
|
|
|
|
|
# vl_llm = get_llm_by_type("vision")
|