diff --git a/src/llms/llm.py b/src/llms/llm.py index 1c7157b..057f6e7 100644 --- a/src/llms/llm.py +++ b/src/llms/llm.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import logging import os from pathlib import Path from typing import Any, Dict, get_args @@ -15,9 +16,57 @@ from src.config import load_yaml_config from src.config.agents import LLMType from src.llms.providers.dashscope import ChatDashscope +logger = logging.getLogger(__name__) + # Cache for LLM instances _llm_cache: dict[LLMType, BaseChatModel] = {} +# Allowed LLM configuration keys to prevent unexpected parameters from being passed +# to LLM constructors (Issue #411 - SEARCH_ENGINE warning fix) +ALLOWED_LLM_CONFIG_KEYS = { + # Common LLM configuration keys + "model", + "api_key", + "base_url", + "api_base", + "max_retries", + "timeout", + "max_tokens", + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + "stream", + "logprobs", + "echo", + "best_of", + "logit_bias", + "user", + "seed", + # SSL and HTTP client settings + "verify_ssl", + "http_client", + "http_async_client", + # Platform-specific keys + "platform", + "google_api_key", + # Azure-specific keys + "azure_endpoint", + "azure_deployment", + "api_version", + "azure_ad_token", + "azure_ad_token_provider", + # Dashscope/Doubao specific keys + "extra_body", + # Token limit for context compression (removed before passing to LLM) + "token_limit", + # Default headers + "default_headers", + "default_query", +} + def _get_config_file_path() -> str: """Get the path to the configuration file.""" @@ -67,6 +116,18 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatMod # Merge configurations, with environment variables taking precedence merged_conf = {**llm_conf, **env_conf} + # Filter out unexpected parameters to prevent LangChain warnings (Issue #411) + # This prevents configuration keys like SEARCH_ENGINE from being passed to 
# LLM constructors + allowed_keys_lower = {k.lower() for k in ALLOWED_LLM_CONFIG_KEYS} + unexpected_keys = [key for key in merged_conf.keys() if key.lower() not in allowed_keys_lower] + for key in unexpected_keys: + removed_value = merged_conf.pop(key) + logger.warning( + f"Removed unexpected LLM configuration key '{key}'. " + f"This key is not a valid LLM parameter and may have been placed in the wrong section of conf.yaml. " + f"Valid LLM config keys include: model, api_key, base_url, max_retries, temperature, etc." + ) + # Remove unnecessary parameters when initializing the client if "token_limit" in merged_conf: merged_conf.pop("token_limit")
diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py index 714b7dc..f485362 100644 --- a/tests/unit/llms/test_llm.py +++ b/tests/unit/llms/test_llm.py @@ -85,3 +85,43 @@ def test_get_llm_by_type_caches(monkeypatch, dummy_conf): inst2 = llm.get_llm_by_type("basic") assert inst1 is inst2 assert called["called"] + + +def test_create_llm_filters_unexpected_keys(monkeypatch, caplog): + """Test that unexpected configuration keys like SEARCH_ENGINE are filtered out (Issue #411).""" + import logging + + # Clear any existing environment variables that might interfere + monkeypatch.delenv("BASIC_MODEL__API_KEY", raising=False) + monkeypatch.delenv("BASIC_MODEL__BASE_URL", raising=False) + monkeypatch.delenv("BASIC_MODEL__MODEL", raising=False) + + # Config with unexpected keys that should be filtered + conf_with_unexpected_keys = { + "BASIC_MODEL": { + "api_key": "test_key", + "base_url": "http://test", + "model": "gpt-4", + "SEARCH_ENGINE": {"include_domains": ["example.com"]}, # Should be filtered + "engine": "tavily", # Should be filtered + } + } + + with caplog.at_level(logging.WARNING): + result = llm._create_llm_use_conf("basic", conf_with_unexpected_keys) + + # Verify the LLM was created + assert isinstance(result, DummyChatOpenAI) + + # Verify unexpected keys were not passed to the LLM + assert "SEARCH_ENGINE" not in result.kwargs + assert "engine" not in result.kwargs + + # Verify valid keys were passed + assert result.kwargs["api_key"] == "test_key" + assert result.kwargs["base_url"] == "http://test" + assert result.kwargs["model"] == "gpt-4" + + # Verify warnings were logged + assert any("SEARCH_ENGINE" in record.message for record in caplog.records) + assert any("engine" in record.message for record in caplog.records)