Mirror of https://gitee.com/wanwujie/deer-flow (synced 2026-04-17 03:34:45 +08:00)
56 lines · 1.6 KiB · Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

from pathlib import Path
from typing import Any, Dict

from langchain_openai import ChatOpenAI

from src.config import load_yaml_config
from src.config.agents import LLMType

# Cache for LLM instances
_llm_cache: dict[LLMType, ChatOpenAI] = {}


def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> ChatOpenAI:
    # Map each LLM type to its configuration section in conf.yaml.
    llm_type_map = {
        "reasoning": conf.get("REASONING_MODEL"),
        "basic": conf.get("BASIC_MODEL"),
        "vision": conf.get("VISION_MODEL"),
    }
    llm_conf = llm_type_map.get(llm_type)
    if not llm_conf:
        raise ValueError(f"Unknown LLM type: {llm_type}")
    if not isinstance(llm_conf, dict):
        raise ValueError(f"Invalid LLM Conf: {llm_type}")
    # The matched section is unpacked directly into ChatOpenAI keyword arguments.
    return ChatOpenAI(**llm_conf)
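
# A minimal sketch of the conf.yaml shape this function expects. The top-level
# keys match llm_type_map above; the nested keys are illustrative assumptions --
# any keyword argument accepted by ChatOpenAI (e.g. model, api_key, base_url)
# can appear here:
#
#   BASIC_MODEL:
#     model: "gpt-4o"
#     api_key: "sk-..."
#     base_url: "https://api.openai.com/v1"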


def get_llm_by_type(
    llm_type: LLMType,
) -> ChatOpenAI:
    """
    Get LLM instance by type. Returns cached instance if available.
    """
    if llm_type in _llm_cache:
        return _llm_cache[llm_type]

    # Load conf.yaml from three directory levels above this module
    # (typically the project root).
    conf = load_yaml_config(
        str((Path(__file__).parent.parent.parent / "conf.yaml").resolve())
    )
    llm = _create_llm_use_conf(llm_type, conf)
    _llm_cache[llm_type] = llm
    return llm


# In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision")


if __name__ == "__main__":
    # Initialize LLMs for different purposes - now these will be cached
    basic_llm = get_llm_by_type("basic")
    print(basic_llm.invoke("Hello"))
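    # Illustrative addition (not part of the original file): a repeated lookup
    # returns the same cached client object instead of constructing a new one.
    assert basic_llm is get_llm_by_type("basic")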