feat: Add LLMs to support the latest open-source SOTA models (#497)

* fix: update README and configuration guide for new model support and reasoning capabilities

* fix: format code for consistency in agent and node files

* fix: update test cases for environment variable handling in llm configuration

* fix: refactor message chunk conversion functions for improved clarity and maintainability

* refactor: remove enable_thinking parameter from LLM configuration functions

* chore: update agent-LLM mapping for consistency

* chore: update LLM configuration handling for improved clarity

* test: add unit tests for Dashscope message chunk conversion and LLM configuration

* test: add unit tests for message chunk conversion in Dashscope

* test: add unit tests for message chunk conversion in Dashscope

* chore: remove unused imports from test_dashscope.py

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Author: CHANGXUBO
Date: 2025-08-13 22:29:22 +08:00
Committed by: GitHub
Commit: d65b8f8fcc (parent: ea17e82514)
6 changed files with 684 additions and 9 deletions


@@ -4,7 +4,7 @@
 from typing import Literal
 # Define available LLM types
-LLMType = Literal["basic", "reasoning", "vision"]
+LLMType = Literal["basic", "reasoning", "vision", "code"]
 # Define agent-LLM mapping
 AGENT_LLM_MAP: dict[str, LLMType] = {
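The hunk is truncated before the mapping entries. As a rough sketch only (these agent names are illustrative, not taken from this commit), an agent can now be pointed at the dedicated code model like so:

from typing import Literal

LLMType = Literal["basic", "reasoning", "vision", "code"]

# Illustrative entries only -- the real agent names live in the truncated mapping above.
AGENT_LLM_MAP: dict[str, LLMType] = {
    "researcher": "basic",
    "coder": "code",
}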


@@ -13,6 +13,7 @@ from typing import get_args
 from src.config import load_yaml_config
 from src.config.agents import LLMType
+from src.llms.providers.dashscope import ChatDashscope
 # Cache for LLM instances
 _llm_cache: dict[LLMType, BaseChatModel] = {}
@@ -29,6 +30,7 @@ def _get_llm_type_config_keys() -> dict[str, str]:
         "reasoning": "REASONING_MODEL",
         "basic": "BASIC_MODEL",
         "vision": "VISION_MODEL",
+        "code": "CODE_MODEL",
     }
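For context, each key added here names the configuration section that a given LLM type is resolved from. A minimal lookup sketch, assuming the configuration is a plain dict loaded from YAML (lookup_llm_conf and the field layout are illustrative, not part of this diff):

from typing import Any, Dict

def _get_llm_type_config_keys() -> Dict[str, str]:
    return {
        "reasoning": "REASONING_MODEL",
        "basic": "BASIC_MODEL",
        "vision": "VISION_MODEL",
        "code": "CODE_MODEL",
    }

def lookup_llm_conf(llm_type: str, conf: Dict[str, Any]) -> Dict[str, Any]:
    # conf might contain e.g. a CODE_MODEL section holding model/base_url/api_key
    # fields (the field names are placeholders here).
    config_key = _get_llm_type_config_keys().get(llm_type)
    if not config_key:
        raise ValueError(f"Unknown LLM type: {llm_type}")
    return conf.get(config_key, {})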
@@ -72,9 +74,6 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatModel:
     if "max_retries" not in merged_conf:
         merged_conf["max_retries"] = 3
-    if llm_type == "reasoning":
-        merged_conf["api_base"] = merged_conf.pop("base_url", None)
     # Handle SSL verification settings
     verify_ssl = merged_conf.pop("verify_ssl", True)
@@ -87,15 +86,23 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatModel:
     if "azure_endpoint" in merged_conf or os.getenv("AZURE_OPENAI_ENDPOINT"):
         return AzureChatOpenAI(**merged_conf)
+    # Check if base_url is dashscope endpoint
+    if "base_url" in merged_conf and "dashscope." in merged_conf["base_url"]:
+        if llm_type == "reasoning":
+            merged_conf["extra_body"] = {"enable_thinking": True}
+        else:
+            merged_conf["extra_body"] = {"enable_thinking": False}
+        return ChatDashscope(**merged_conf)
     if llm_type == "reasoning":
+        merged_conf["api_base"] = merged_conf.pop("base_url", None)
         return ChatDeepSeek(**merged_conf)
     else:
         return ChatOpenAI(**merged_conf)
-def get_llm_by_type(
-    llm_type: LLMType,
-) -> BaseChatModel:
+def get_llm_by_type(llm_type: LLMType) -> BaseChatModel:
     """
     Get LLM instance by type. Returns cached instance if available.
     """


@@ -0,0 +1,321 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

# Standard library imports
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union, cast

# Third-party imports
import openai
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessageChunk,
    FunctionMessageChunk,
    HumanMessageChunk,
    SystemMessageChunk,
    ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
    _create_usage_metadata,
    _handle_openai_bad_request,
    warnings,
)
def _convert_delta_to_message_chunk(
    delta_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert a delta dictionary to a message chunk.

    Args:
        delta_dict: Dictionary containing delta information from OpenAI response
        default_class: Default message chunk class to use if role is not specified

    Returns:
        BaseMessageChunk: Appropriate message chunk based on role and content

    Raises:
        KeyError: If required keys are missing from the delta dictionary
    """
    message_id = delta_dict.get("id")
    role = cast(str, delta_dict.get("role", ""))
    content = cast(str, delta_dict.get("content") or "")
    additional_kwargs: Dict[str, Any] = {}

    # Handle function calls
    if function_call_data := delta_dict.get("function_call"):
        function_call = dict(function_call_data)
        if "name" in function_call and function_call["name"] is None:
            function_call["name"] = ""
        additional_kwargs["function_call"] = function_call

    # Handle tool calls
    tool_call_chunks = []
    if raw_tool_calls := delta_dict.get("tool_calls"):
        additional_kwargs["tool_calls"] = raw_tool_calls
        try:
            tool_call_chunks = [
                tool_call_chunk(
                    name=rtc.get("function", {}).get("name"),
                    args=rtc.get("function", {}).get("arguments"),
                    id=rtc.get("id"),
                    index=rtc.get("index", 0),
                )
                for rtc in raw_tool_calls
                if rtc.get("function")  # Ensure function key exists
            ]
        except (KeyError, TypeError):
            # Log the error but continue processing
            pass

    # Return appropriate message chunk based on role
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content, id=message_id)
    elif role == "assistant" or default_class == AIMessageChunk:
        # Handle reasoning content for OpenAI reasoning models
        if reasoning_content := delta_dict.get("reasoning_content"):
            additional_kwargs["reasoning_content"] = reasoning_content
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            id=message_id,
            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
        )
    elif role in ("system", "developer") or default_class == SystemMessageChunk:
        if role == "developer":
            additional_kwargs = {"__openai_role__": "developer"}
        return SystemMessageChunk(
            content=content, id=message_id, additional_kwargs=additional_kwargs
        )
    elif role == "function" or default_class == FunctionMessageChunk:
        function_name = delta_dict.get("name", "")
        return FunctionMessageChunk(content=content, name=function_name, id=message_id)
    elif role == "tool" or default_class == ToolMessageChunk:
        tool_call_id = delta_dict.get("tool_call_id", "")
        return ToolMessageChunk(
            content=content, tool_call_id=tool_call_id, id=message_id
        )
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role, id=message_id)
    else:
        return default_class(content=content, id=message_id)  # type: ignore
def _convert_chunk_to_generation_chunk(
    chunk: Dict[str, Any],
    default_chunk_class: Type[BaseMessageChunk],
    base_generation_info: Optional[Dict[str, Any]],
) -> Optional[ChatGenerationChunk]:
    """Convert a streaming chunk to a generation chunk.

    Args:
        chunk: Raw chunk data from OpenAI streaming response
        default_chunk_class: Default message chunk class to use
        base_generation_info: Base generation information to include

    Returns:
        Optional[ChatGenerationChunk]: Generated chunk or None if chunk should be skipped
    """
    # Skip content.delta type chunks from beta.chat.completions.stream
    if chunk.get("type") == "content.delta":
        return None

    token_usage = chunk.get("usage")
    choices = (
        chunk.get("choices", [])
        # Handle chunks from beta.chat.completions.stream format
        or chunk.get("chunk", {}).get("choices", [])
    )
    usage_metadata: Optional[UsageMetadata] = (
        _create_usage_metadata(token_usage) if token_usage else None
    )

    # Handle empty choices
    if not choices:
        generation_chunk = ChatGenerationChunk(
            message=default_chunk_class(content="", usage_metadata=usage_metadata)
        )
        return generation_chunk

    choice = choices[0]
    if choice.get("delta") is None:
        return None

    message_chunk = _convert_delta_to_message_chunk(
        choice["delta"], default_chunk_class
    )
    generation_info = dict(base_generation_info) if base_generation_info else {}

    # Add finish reason and model info if available
    if finish_reason := choice.get("finish_reason"):
        generation_info["finish_reason"] = finish_reason
        if model_name := chunk.get("model"):
            generation_info["model_name"] = model_name
        if system_fingerprint := chunk.get("system_fingerprint"):
            generation_info["system_fingerprint"] = system_fingerprint

    # Add log probabilities if available
    if logprobs := choice.get("logprobs"):
        generation_info["logprobs"] = logprobs

    # Attach usage metadata to AI message chunks
    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
        message_chunk.usage_metadata = usage_metadata

    generation_chunk = ChatGenerationChunk(
        message=message_chunk, generation_info=generation_info or None
    )
    return generation_chunk
class ChatDashscope(ChatOpenAI):
    """Extended ChatOpenAI model with reasoning capabilities.

    This class extends the base ChatOpenAI model to support OpenAI's reasoning models
    that include reasoning_content in their responses. It handles the extraction and
    preservation of reasoning content during both streaming and non-streaming operations.
    """

    def _create_chat_result(
        self,
        response: Union[Dict[str, Any], openai.BaseModel],
        generation_info: Optional[Dict[str, Any]] = None,
    ) -> ChatResult:
        """Create a chat result from the OpenAI response.

        Args:
            response: The response from OpenAI API
            generation_info: Additional generation information

        Returns:
            ChatResult: The formatted chat result with reasoning content if available
        """
        chat_result = super()._create_chat_result(response, generation_info)

        # Only process BaseModel responses (not raw dict responses)
        if not isinstance(response, openai.BaseModel):
            return chat_result

        # Extract reasoning content if available
        try:
            if (
                hasattr(response, "choices")
                and response.choices
                and hasattr(response.choices[0], "message")
                and hasattr(response.choices[0].message, "reasoning_content")
            ):
                reasoning_content = response.choices[0].message.reasoning_content
                if reasoning_content and chat_result.generations:
                    chat_result.generations[0].message.additional_kwargs[
                        "reasoning_content"
                    ] = reasoning_content
        except (IndexError, AttributeError):
            # If reasoning content extraction fails, continue without it
            pass

        return chat_result
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Create a streaming generator for chat completions.

        Args:
            messages: List of messages to send to the model
            stop: Optional list of stop sequences
            run_manager: Optional callback manager for LLM runs
            **kwargs: Additional keyword arguments for the API call

        Yields:
            ChatGenerationChunk: Individual chunks from the streaming response

        Raises:
            openai.BadRequestError: If the API request is invalid
        """
        kwargs["stream"] = True
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        base_generation_info: Dict[str, Any] = {}

        # Handle response format for beta completions
        if "response_format" in payload:
            if self.include_response_headers:
                warnings.warn(
                    "Cannot currently include response headers when response_format is "
                    "specified."
                )
            payload.pop("stream")
            response_stream = self.root_client.beta.chat.completions.stream(**payload)
            context_manager = response_stream
        else:
            # Handle regular streaming with optional response headers
            if self.include_response_headers:
                raw_response = self.client.with_raw_response.create(**payload)
                response = raw_response.parse()
                base_generation_info = {"headers": dict(raw_response.headers)}
            else:
                response = self.client.create(**payload)
            context_manager = response

        try:
            with context_manager as response:
                is_first_chunk = True
                for chunk in response:
                    # Convert chunk to dict if it's a model object
                    if not isinstance(chunk, dict):
                        chunk = chunk.model_dump()

                    generation_chunk = _convert_chunk_to_generation_chunk(
                        chunk,
                        default_chunk_class,
                        base_generation_info if is_first_chunk else {},
                    )
                    if generation_chunk is None:
                        continue

                    # Update default chunk class for subsequent chunks
                    default_chunk_class = generation_chunk.message.__class__

                    # Handle log probabilities for callback
                    logprobs = (generation_chunk.generation_info or {}).get("logprobs")
                    if run_manager:
                        run_manager.on_llm_new_token(
                            generation_chunk.text,
                            chunk=generation_chunk,
                            logprobs=logprobs,
                        )

                    is_first_chunk = False
                    yield generation_chunk
        except openai.BadRequestError as e:
            _handle_openai_bad_request(e)

        # Handle final completion for response_format requests
        if hasattr(response, "get_final_completion") and "response_format" in payload:
            try:
                final_completion = response.get_final_completion()
                generation_chunk = self._get_generation_chunk_from_completion(
                    final_completion
                )
                if run_manager:
                    run_manager.on_llm_new_token(
                        generation_chunk.text, chunk=generation_chunk
                    )
                yield generation_chunk
            except AttributeError:
                # If get_final_completion method doesn't exist, continue without it
                pass
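The unit tests mentioned in the commit message target the two module-level helpers above. A self-contained sketch in that spirit, using hand-built delta/chunk dicts rather than the repo's actual test cases (the module path is taken from the import shown earlier in this diff):

from langchain_core.messages import AIMessageChunk

from src.llms.providers.dashscope import (
    _convert_chunk_to_generation_chunk,
    _convert_delta_to_message_chunk,
)

# A delta carrying both visible content and Dashscope-style reasoning_content.
delta = {
    "role": "assistant",
    "content": "Hello",
    "reasoning_content": "Thinking about a greeting...",
}
msg = _convert_delta_to_message_chunk(delta, AIMessageChunk)
assert isinstance(msg, AIMessageChunk)
assert msg.additional_kwargs["reasoning_content"] == "Thinking about a greeting..."

# A full streaming chunk as it would arrive from the OpenAI-compatible endpoint.
chunk = {
    "id": "chatcmpl-123",
    "model": "qwen-plus",
    "choices": [
        {
            "index": 0,
            "delta": {"role": "assistant", "content": "Hello"},
            "finish_reason": None,
        }
    ],
}
gen = _convert_chunk_to_generation_chunk(chunk, AIMessageChunk, {})
assert gen is not None
assert gen.message.content == "Hello"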