mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
refactor: 重构 Chat Completions 端点,采用类型安全的 Responses API 转换
将 /v1/chat/completions 端点从 ResponseWriter 劫持模式重构为独立的 类型安全转换路径,与 Anthropic Messages 端点架构对齐: - 在 apicompat 包新增 Chat Completions 完整类型定义和双向转换器 - 新增 ForwardAsChatCompletions service 方法,走 Responses API 上游 - Handler 改为独立的账号选择/failover 循环,不再劫持 Responses handler - 提取 handleCompatErrorResponse 为 Chat Completions 和 Messages 共用 - 删除旧的 forwardChatCompletions 直传路径及相关死代码
This commit is contained in:
@@ -1,23 +1,53 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API compatibility requests.
// POST /v1/chat/completions
|
||||
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -31,516 +61,230 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Preserve original chat-completions request for upstream passthrough when needed.
|
||||
c.Set(service.OpenAIChatCompletionsBodyKey, body)
|
||||
|
||||
var chatReq map[string]any
|
||||
if err := json.Unmarshal(body, &chatReq); err != nil {
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
includeUsage := false
|
||||
if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok {
|
||||
if v, ok := streamOptions["include_usage"].(bool); ok {
|
||||
includeUsage = v
|
||||
}
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage)
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
converted, err := service.ConvertChatCompletionsToResponses(chatReq)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
convertedBody, err := json.Marshal(converted)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
stream, _ := converted["stream"].(bool)
|
||||
model, _ := converted["model"].(string)
|
||||
originalWriter := c.Writer
|
||||
writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model)
|
||||
c.Writer = writer
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody))
|
||||
c.Request.ContentLength = int64(len(convertedBody))
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
h.Responses(c)
|
||||
writer.Finalize()
|
||||
c.Writer = originalWriter
|
||||
}
|
||||
|
||||
// chatCompletionsResponseWriter wraps a gin.ResponseWriter and rewrites an
// upstream Responses-API reply (streamed SSE or a buffered JSON body) into
// the Chat Completions wire format before it reaches the client.
type chatCompletionsResponseWriter struct {
	gin.ResponseWriter

	stream       bool // client requested SSE streaming
	includeUsage bool // emit usage on the final chunk (stream_options.include_usage)

	buffer    bytes.Buffer // non-streaming: whole upstream body, converted in Finalize
	streamBuf bytes.Buffer // streaming: accumulator for partial SSE lines

	state     *chatCompletionStreamState  // per-response id/model/usage bookkeeping
	corrector *service.CodexToolCorrector // fixes tool-call payloads in emitted JSON

	finalized   bool // Finalize already ran; later non-stream writes are dropped
	passthrough bool // forward bytes verbatim, skipping all conversion
}
|
||||
|
||||
// chatCompletionStreamState tracks metadata accumulated while translating a
// streamed Responses-API reply into chat.completion.chunk events.
type chatCompletionStreamState struct {
	id      string // chunk id; backfilled from upstream metadata or generated
	model   string // model name echoed in every chunk
	created int64  // unix-seconds timestamp stamped on every chunk

	sentRole    bool // assistant role already announced in a delta
	sawToolCall bool // at least one tool-call delta was emitted
	sawText     bool // at least one text delta was emitted

	toolCallIndex map[string]int // stable chunk index assigned per tool-call ID
	usage         map[string]any // prompt/completion token usage, when reported upstream
}
|
||||
|
||||
func newChatCompletionsResponseWriter(w gin.ResponseWriter, stream bool, includeUsage bool, model string) *chatCompletionsResponseWriter {
|
||||
return &chatCompletionsResponseWriter{
|
||||
ResponseWriter: w,
|
||||
stream: stream,
|
||||
includeUsage: includeUsage,
|
||||
state: &chatCompletionStreamState{
|
||||
model: strings.TrimSpace(model),
|
||||
toolCallIndex: make(map[string]int),
|
||||
},
|
||||
corrector: service.NewCodexToolCorrector(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) {
|
||||
if w.passthrough {
|
||||
return w.ResponseWriter.Write(data)
|
||||
}
|
||||
if w.stream {
|
||||
n, err := w.streamBuf.Write(data)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
w.flushStreamBuffer()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if w.finalized {
|
||||
return len(data), nil
|
||||
}
|
||||
return w.buffer.Write(data)
|
||||
}
|
||||
|
||||
// WriteString forwards to Write so string and byte writes share one code path.
func (w *chatCompletionsResponseWriter) WriteString(s string) (int, error) {
	return w.Write([]byte(s))
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) Finalize() {
|
||||
if w.finalized {
|
||||
return
|
||||
}
|
||||
w.finalized = true
|
||||
if w.passthrough {
|
||||
return
|
||||
}
|
||||
if w.stream {
|
||||
return
|
||||
}
|
||||
|
||||
body := w.buffer.Bytes()
|
||||
if len(body) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Del("Content-Length")
|
||||
|
||||
converted, err := service.ConvertResponsesToChatCompletion(body)
|
||||
if err != nil {
|
||||
_, _ = w.ResponseWriter.Write(body)
|
||||
return
|
||||
}
|
||||
|
||||
corrected := converted
|
||||
if correctedStr, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok {
|
||||
corrected = []byte(correctedStr)
|
||||
}
|
||||
|
||||
_, _ = w.ResponseWriter.Write(corrected)
|
||||
}
|
||||
|
||||
// SetPassthrough switches the writer to verbatim forwarding; subsequent
// writes bypass buffering and conversion entirely.
func (w *chatCompletionsResponseWriter) SetPassthrough() {
	w.passthrough = true
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) Status() int {
|
||||
if w.ResponseWriter == nil {
|
||||
return 0
|
||||
}
|
||||
return w.ResponseWriter.Status()
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) Written() bool {
|
||||
if w.ResponseWriter == nil {
|
||||
return false
|
||||
}
|
||||
return w.ResponseWriter.Written()
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) flushStreamBuffer() {
|
||||
for {
|
||||
buf := w.streamBuf.Bytes()
|
||||
idx := bytes.IndexByte(buf, '\n')
|
||||
if idx == -1 {
|
||||
return
|
||||
}
|
||||
lineBytes := w.streamBuf.Next(idx + 1)
|
||||
line := strings.TrimRight(string(lineBytes), "\r\n")
|
||||
w.handleStreamLine(line)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) handleStreamLine(line string) {
|
||||
if line == "" {
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(line, ":") {
|
||||
_, _ = w.ResponseWriter.Write([]byte(line + "\n\n"))
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
return
|
||||
}
|
||||
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
for _, chunk := range w.convertResponseDataToChatChunks(data) {
|
||||
if chunk == "" {
|
||||
continue
|
||||
}
|
||||
if chunk == "[DONE]" {
|
||||
_, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
continue
|
||||
}
|
||||
_, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string {
|
||||
if data == "" {
|
||||
return nil
|
||||
}
|
||||
if data == "[DONE]" {
|
||||
return []string{"[DONE]"}
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
return []string{data}
|
||||
}
|
||||
|
||||
if _, ok := payload["error"]; ok {
|
||||
return []string{data}
|
||||
}
|
||||
|
||||
eventType := strings.TrimSpace(getString(payload["type"]))
|
||||
if eventType == "" {
|
||||
return []string{data}
|
||||
}
|
||||
|
||||
w.state.applyMetadata(payload)
|
||||
|
||||
switch eventType {
|
||||
case "response.created":
|
||||
return nil
|
||||
case "response.output_text.delta":
|
||||
delta := getString(payload["delta"])
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
w.state.sawText = true
|
||||
return []string{w.buildTextDeltaChunk(delta)}
|
||||
case "response.output_text.done":
|
||||
if w.state.sawText {
|
||||
return nil
|
||||
}
|
||||
text := getString(payload["text"])
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
w.state.sawText = true
|
||||
return []string{w.buildTextDeltaChunk(text)}
|
||||
case "response.output_item.added", "response.output_item.delta":
|
||||
if item, ok := payload["item"].(map[string]any); ok {
|
||||
if callID, name, args, ok := extractToolCallFromItem(item); ok {
|
||||
w.state.sawToolCall = true
|
||||
return []string{w.buildToolCallChunk(callID, name, args)}
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
case "response.completed", "response.done":
|
||||
if responseObj, ok := payload["response"].(map[string]any); ok {
|
||||
w.state.applyResponseUsage(responseObj)
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
return []string{w.buildFinalChunk()}
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") {
|
||||
callID := strings.TrimSpace(getString(payload["call_id"]))
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(getString(payload["tool_call_id"]))
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(getString(payload["id"]))
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
args := getString(payload["delta"])
|
||||
name := strings.TrimSpace(getString(payload["name"]))
|
||||
if callID != "" && (args != "" || name != "") {
|
||||
w.state.sawToolCall = true
|
||||
return []string{w.buildToolCallChunk(callID, name, args)}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string {
|
||||
w.state.ensureDefaults()
|
||||
payload := map[string]any{
|
||||
"content": delta,
|
||||
}
|
||||
if !w.state.sentRole {
|
||||
payload["role"] = "assistant"
|
||||
w.state.sentRole = true
|
||||
}
|
||||
return w.buildChunk(payload, nil, nil)
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string {
|
||||
w.state.ensureDefaults()
|
||||
index := w.state.toolCallIndexFor(callID)
|
||||
function := map[string]any{}
|
||||
if name != "" {
|
||||
function["name"] = name
|
||||
}
|
||||
if args != "" {
|
||||
function["arguments"] = args
|
||||
}
|
||||
toolCall := map[string]any{
|
||||
"index": index,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": function,
|
||||
}
|
||||
|
||||
delta := map[string]any{
|
||||
"tool_calls": []any{toolCall},
|
||||
}
|
||||
if !w.state.sentRole {
|
||||
delta["role"] = "assistant"
|
||||
w.state.sentRole = true
|
||||
}
|
||||
|
||||
return w.buildChunk(delta, nil, nil)
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) buildFinalChunk() string {
|
||||
w.state.ensureDefaults()
|
||||
finishReason := "stop"
|
||||
if w.state.sawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
usage := map[string]any(nil)
|
||||
if w.includeUsage && w.state.usage != nil {
|
||||
usage = w.state.usage
|
||||
}
|
||||
return w.buildChunk(map[string]any{}, finishReason, usage)
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string {
|
||||
w.state.ensureDefaults()
|
||||
chunk := map[string]any{
|
||||
"id": w.state.id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": w.state.created,
|
||||
"model": w.state.model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
if usage != nil {
|
||||
chunk["usage"] = usage
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(chunk)
|
||||
if corrected, ok := w.corrector.CorrectToolCallsInSSEData(string(data)); ok {
|
||||
return corrected
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// ensureDefaults backfills id, model, and created so every emitted chunk is
// well-formed even before upstream metadata has arrived.
func (s *chatCompletionStreamState) ensureDefaults() {
	// Generated ids use the conventional "chatcmpl-" prefix.
	if s.id == "" {
		s.id = "chatcmpl-" + randomHexUnsafe(12)
	}
	if s.model == "" {
		s.model = "unknown"
	}
	if s.created == 0 {
		s.created = time.Now().Unix()
	}
}
|
||||
|
||||
func (s *chatCompletionStreamState) toolCallIndexFor(callID string) int {
|
||||
if idx, ok := s.toolCallIndex[callID]; ok {
|
||||
return idx
|
||||
}
|
||||
idx := len(s.toolCallIndex)
|
||||
s.toolCallIndex[callID] = idx
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) {
|
||||
if responseObj, ok := payload["response"].(map[string]any); ok {
|
||||
s.applyResponseMetadata(responseObj)
|
||||
}
|
||||
|
||||
if s.id == "" {
|
||||
if id := strings.TrimSpace(getString(payload["response_id"])); id != "" {
|
||||
s.id = id
|
||||
} else if id := strings.TrimSpace(getString(payload["id"])); id != "" {
|
||||
s.id = id
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}
|
||||
if s.model == "" {
|
||||
if model := strings.TrimSpace(getString(payload["model"])); model != "" {
|
||||
s.model = model
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
}
|
||||
if s.created == 0 {
|
||||
if created := getInt64(payload["created_at"]); created != 0 {
|
||||
s.created = created
|
||||
} else if created := getInt64(payload["created"]); created != 0 {
|
||||
s.created = created
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) {
|
||||
if s.id == "" {
|
||||
if id := strings.TrimSpace(getString(responseObj["id"])); id != "" {
|
||||
s.id = id
|
||||
}
|
||||
}
|
||||
if s.model == "" {
|
||||
if model := strings.TrimSpace(getString(responseObj["model"])); model != "" {
|
||||
s.model = model
|
||||
}
|
||||
}
|
||||
if s.created == 0 {
|
||||
if created := getInt64(responseObj["created_at"]); created != 0 {
|
||||
s.created = created
|
||||
}
|
||||
}
|
||||
}
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) {
|
||||
usage, ok := responseObj["usage"].(map[string]any)
|
||||
if !ok {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_chat_completions.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
promptTokens := int(getNumber(usage["input_tokens"]))
|
||||
completionTokens := int(getNumber(usage["output_tokens"]))
|
||||
if promptTokens == 0 && completionTokens == 0 {
|
||||
return
|
||||
}
|
||||
s.usage = map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": completionTokens,
|
||||
"total_tokens": promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func extractToolCallFromItem(item map[string]any) (string, string, string, bool) {
|
||||
itemType := strings.TrimSpace(getString(item["type"]))
|
||||
if itemType != "tool_call" && itemType != "function_call" {
|
||||
return "", "", "", false
|
||||
}
|
||||
callID := strings.TrimSpace(getString(item["call_id"]))
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(getString(item["id"]))
|
||||
}
|
||||
name := strings.TrimSpace(getString(item["name"]))
|
||||
args := getString(item["arguments"])
|
||||
if fn, ok := item["function"].(map[string]any); ok {
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(getString(fn["name"]))
|
||||
}
|
||||
if args == "" {
|
||||
args = getString(fn["arguments"])
|
||||
}
|
||||
}
|
||||
if callID == "" && name == "" && args == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_" + randomHexUnsafe(6)
|
||||
}
|
||||
return callID, name, args, true
|
||||
}
|
||||
|
||||
// getString coerces a decoded JSON value to a string; string, []byte, and
// json.Number are supported, everything else yields "".
func getString(value any) string {
	if s, ok := value.(string); ok {
		return s
	}
	if b, ok := value.([]byte); ok {
		return string(b)
	}
	if n, ok := value.(json.Number); ok {
		return n.String()
	}
	return ""
}
|
||||
|
||||
// getNumber coerces common numeric representations of a decoded JSON value
// to float64; non-numeric values yield 0.
func getNumber(value any) float64 {
	switch v := value.(type) {
	case json.Number:
		f, _ := v.Float64()
		return f
	case float64:
		return v
	case float32:
		return float64(v)
	case int64:
		return float64(v)
	case int:
		return float64(v)
	}
	return 0
}
|
||||
|
||||
// getInt64 coerces a decoded JSON value to int64 (floats are truncated
// toward zero); non-numeric values yield 0.
func getInt64(value any) int64 {
	switch v := value.(type) {
	case json.Number:
		i, _ := v.Int64()
		return i
	case float64:
		return int64(v)
	case int:
		return int64(v)
	case int64:
		return v
	}
	return 0
}
|
||||
|
||||
// randomHexUnsafe returns byteLength random bytes hex-encoded, i.e. a
// string of 2*byteLength characters; non-positive byteLength falls back to
// 8 bytes. "Unsafe" refers only to the fallback: if the entropy source
// fails, a zero-filled string of the same length is returned instead of an
// error. (The original returned a fixed 6-char "000000" on failure,
// breaking the length contract for every other byteLength.)
func randomHexUnsafe(byteLength int) string {
	if byteLength <= 0 {
		byteLength = 8
	}
	buf := make([]byte, byteLength)
	if _, err := rand.Read(buf); err != nil {
		// Preserve the 2*byteLength length contract even without entropy.
		return strings.Repeat("0", byteLength*2)
	}
	return hex.EncodeToString(buf)
}
|
||||
|
||||
Reference in New Issue
Block a user