mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-04 06:32:13 +08:00
71 lines
1.9 KiB
Python
71 lines
1.9 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import os
|
|
import types
|
|
import pytest
|
|
from src.llms import llm
|
|
|
|
|
|
class DummyChatOpenAI:
|
|
def __init__(self, **kwargs):
|
|
self.kwargs = kwargs
|
|
|
|
def invoke(self, msg):
|
|
return f"Echo: {msg}"
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def patch_chat_openai(monkeypatch):
|
|
monkeypatch.setattr(llm, "ChatOpenAI", DummyChatOpenAI)
|
|
|
|
|
|
@pytest.fixture
|
|
def dummy_conf():
|
|
return {
|
|
"BASIC_MODEL": {"api_key": "test_key", "base_url": "http://test"},
|
|
"REASONING_MODEL": {"api_key": "reason_key"},
|
|
"VISION_MODEL": {"api_key": "vision_key"},
|
|
}
|
|
|
|
|
|
def test_get_env_llm_conf(monkeypatch):
|
|
monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
|
|
monkeypatch.setenv("BASIC_MODEL__BASE_URL", "http://env")
|
|
conf = llm._get_env_llm_conf("basic")
|
|
assert conf["api_key"] == "env_key"
|
|
assert conf["base_url"] == "http://env"
|
|
|
|
|
|
def test_create_llm_use_conf_merges_env(monkeypatch, dummy_conf):
|
|
monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
|
|
result = llm._create_llm_use_conf("basic", dummy_conf)
|
|
assert isinstance(result, DummyChatOpenAI)
|
|
assert result.kwargs["api_key"] == "env_key"
|
|
assert result.kwargs["base_url"] == "http://test"
|
|
|
|
|
|
def test_create_llm_use_conf_invalid_type(dummy_conf):
|
|
with pytest.raises(ValueError):
|
|
llm._create_llm_use_conf("unknown", dummy_conf)
|
|
|
|
|
|
def test_create_llm_use_conf_empty_conf(monkeypatch):
|
|
with pytest.raises(ValueError):
|
|
llm._create_llm_use_conf("basic", {})
|
|
|
|
|
|
def test_get_llm_by_type_caches(monkeypatch, dummy_conf):
|
|
called = {}
|
|
|
|
def fake_load_yaml_config(path):
|
|
called["called"] = True
|
|
return dummy_conf
|
|
|
|
monkeypatch.setattr(llm, "load_yaml_config", fake_load_yaml_config)
|
|
llm._llm_cache.clear()
|
|
inst1 = llm.get_llm_by_type("basic")
|
|
inst2 = llm.get_llm_by_type("basic")
|
|
assert inst1 is inst2
|
|
assert called["called"]
|