diff --git a/docs/configuration_guide.md b/docs/configuration_guide.md index 610e442..d21ec62 100644 --- a/docs/configuration_guide.md +++ b/docs/configuration_guide.md @@ -105,16 +105,18 @@ BASIC_MODEL: Note: The available models and their exact names may change over time. Please verify the currently available models and their correct identifiers in [OpenRouter's official documentation](https://openrouter.ai/docs). -### How to use Azure models? -DeerFlow supports the integration of Azure models. You can refer to [litellm Azure](https://docs.litellm.ai/docs/providers/azure). Configuration example of `conf.yaml`: +### How to use Azure OpenAI chat models? + +DeerFlow supports the integration of Azure OpenAI chat models. You can refer to [AzureChatOpenAI](https://python.langchain.com/api_reference/openai/chat_models/langchain_openai.chat_models.azure.AzureChatOpenAI.html). Configuration example of `conf.yaml`: ```yaml BASIC_MODEL: model: "azure/gpt-4o-2024-08-06" - api_base: $AZURE_API_BASE - api_version: $AZURE_API_VERSION - api_key: $AZURE_API_KEY + azure_endpoint: $AZURE_OPENAI_ENDPOINT + api_version: $OPENAI_API_VERSION + api_key: $AZURE_OPENAI_API_KEY ``` + ## About Search Engine ### How to control search domains for Tavily? @@ -136,4 +138,5 @@ SEARCH_ENGINE: # Exclude results from these domains (blacklist) exclude_domains: - unreliable-site.com - - spam-domain.net \ No newline at end of file + - spam-domain.net + diff --git a/src/llms/llm.py b/src/llms/llm.py index aa78961..f91e9c6 100644 --- a/src/llms/llm.py +++ b/src/llms/llm.py @@ -6,7 +6,8 @@ from typing import Any, Dict import os import httpx -from langchain_openai import ChatOpenAI +from langchain_core.language_models import BaseChatModel +from langchain_openai import ChatOpenAI, AzureChatOpenAI from langchain_deepseek import ChatDeepSeek from typing import get_args @@ -14,7 +15,7 @@ from src.config import load_yaml_config from src.config.agents import LLMType # Cache for LLM instances -_llm_cache: dict[LLMType, ChatOpenAI] = {} +_llm_cache: dict[LLMType, BaseChatModel] = {} def _get_config_file_path() -> str: @@ -48,7 +49,7 @@ def _get_env_llm_conf(llm_type: str) -> Dict[str, Any]: def _create_llm_use_conf( llm_type: LLMType, conf: Dict[str, Any] -) -> ChatOpenAI | ChatDeepSeek: +) -> BaseChatModel : """Create LLM instance using configuration.""" llm_type_config_keys = _get_llm_type_config_keys() config_key = llm_type_config_keys.get(llm_type) @@ -86,16 +87,17 @@ def _create_llm_use_conf( merged_conf["http_client"] = http_client merged_conf["http_async_client"] = http_async_client - return ( - ChatOpenAI(**merged_conf) - if llm_type != "reasoning" - else ChatDeepSeek(**merged_conf) - ) - + if "azure_endpoint" in merged_conf or os.getenv("AZURE_OPENAI_ENDPOINT"): + return AzureChatOpenAI(**merged_conf) + if llm_type == "reasoning": + return ChatDeepSeek(**merged_conf) + else + return ChatOpenAI(**merged_conf) + def get_llm_by_type( llm_type: LLMType, -) -> ChatOpenAI: +) -> BaseChatModel: """ Get LLM instance by type. Returns cached instance if available. """