From 2554e4ba639879e378f0ba94d37661eaad040d4e Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Wed, 11 Jun 2025 19:46:01 +0800 Subject: [PATCH] test: add unit tests of llms (#299) --- src/llms/llm.py | 6 ---- tests/unit/llms/test_llm.py | 70 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 6 deletions(-) create mode 100644 tests/unit/llms/test_llm.py diff --git a/src/llms/llm.py b/src/llms/llm.py index 3f31189..e88d1c1 100644 --- a/src/llms/llm.py +++ b/src/llms/llm.py @@ -70,9 +70,3 @@ def get_llm_by_type( # In the future, we will use reasoning_llm and vl_llm for different purposes # reasoning_llm = get_llm_by_type("reasoning") # vl_llm = get_llm_by_type("vision") - - -if __name__ == "__main__": - # Initialize LLMs for different purposes - now these will be cached - basic_llm = get_llm_by_type("basic") - print(basic_llm.invoke("Hello")) diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py new file mode 100644 index 0000000..2587080 --- /dev/null +++ b/tests/unit/llms/test_llm.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025 Bytedance Ltd. 
# and/or its affiliates
# SPDX-License-Identifier: MIT
"""Unit tests for src.llms.llm.

ChatOpenAI is replaced with a lightweight dummy for every test (autouse
fixture), so no network access or real credentials are needed.
"""

import pytest

from src.llms import llm


class DummyChatOpenAI:
    """Stand-in for ChatOpenAI: records constructor kwargs, echoes invoke input."""

    def __init__(self, **kwargs):
        # Keep kwargs so tests can assert which configuration reached the client.
        self.kwargs = kwargs

    def invoke(self, msg):
        return f"Echo: {msg}"


@pytest.fixture(autouse=True)
def patch_chat_openai(monkeypatch):
    # Autouse: every test in this module talks to the dummy, never the real client.
    monkeypatch.setattr(llm, "ChatOpenAI", DummyChatOpenAI)


@pytest.fixture
def dummy_conf():
    # Minimal YAML-config shape covering all three model types used by llm.py.
    return {
        "BASIC_MODEL": {"api_key": "test_key", "base_url": "http://test"},
        "REASONING_MODEL": {"api_key": "reason_key"},
        "VISION_MODEL": {"api_key": "vision_key"},
    }


def test_get_env_llm_conf(monkeypatch):
    # BASIC_MODEL__<KEY> environment variables map to lower-cased conf keys.
    monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
    monkeypatch.setenv("BASIC_MODEL__BASE_URL", "http://env")
    conf = llm._get_env_llm_conf("basic")
    assert conf["api_key"] == "env_key"
    assert conf["base_url"] == "http://env"


def test_create_llm_use_conf_merges_env(monkeypatch, dummy_conf):
    # Environment values override file config; untouched file keys survive.
    monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
    result = llm._create_llm_use_conf("basic", dummy_conf)
    assert isinstance(result, DummyChatOpenAI)
    assert result.kwargs["api_key"] == "env_key"
    assert result.kwargs["base_url"] == "http://test"


def test_create_llm_use_conf_invalid_type(dummy_conf):
    # An unknown LLM type must be rejected, not silently defaulted.
    with pytest.raises(ValueError):
        llm._create_llm_use_conf("unknown", dummy_conf)


def test_create_llm_use_conf_empty_conf():
    # A known type with no configuration at all is also an error.
    # (No monkeypatching needed here — fixed: the unused `monkeypatch`
    # parameter from the original has been dropped.)
    with pytest.raises(ValueError):
        llm._create_llm_use_conf("basic", {})


def test_get_llm_by_type_caches(monkeypatch, dummy_conf):
    # Second lookup for the same type must return the cached instance,
    # and the YAML config loader must actually have been consulted.
    called = {}

    def fake_load_yaml_config(path):
        called["called"] = True
        return dummy_conf

    monkeypatch.setattr(llm, "load_yaml_config", fake_load_yaml_config)
    llm._llm_cache.clear()
    inst1 = llm.get_llm_by_type("basic")
    inst2 = llm.get_llm_by_type("basic")
    assert inst1 is inst2
    assert called["called"]