package handler

import (
	"context"
	"errors"
	"net/http"
	"time"

	pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
	"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
	"github.com/Wei-Shaw/sub2api/internal/service"

	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"go.uber.org/zap"
)

// ChatCompletions handles OpenAI Chat Completions API requests.
// POST /v1/chat/completions
//
// Flow, as implemented below:
//   - Resolve the API key and auth subject from the gin context.
//   - Read the request body and validate it: non-empty, valid JSON, and a
//     non-empty string "model" field.
//   - Acquire a per-user slot and check billing eligibility.
//   - Loop: select an upstream account, acquire an account slot and forward
//     the request. An UpstreamFailoverError either retries the same account
//     (when flagged RetryableOnSameAccount and under the account's retry
//     limit) or excludes the account and switches to another one, up to
//     h.maxAccountSwitches switches.
//   - On success, report the schedule result and submit a usage-record task
//     through h.submitUsageRecordTask.
//
// Both the user slot and the account slot are released through deferred
// calls, so they are also released when a panic unwinds this handler into
// recoverResponsesPanic.
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
	// streamStarted is shared by pointer with the recovery and slot helpers.
	// NOTE(review): presumably they flip it once streaming output has begun so
	// later errors can be reported in a streaming-aware way — confirm.
	streamStarted := false
	defer h.recoverResponsesPanic(c, &streamStarted)

	requestStart := time.Now()

	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}

	reqLog := requestLogger(
		c,
		"handler.openai_gateway.chat_completions",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
	)

	// The helper decides whether the handler can proceed; this function does
	// not write a response itself when it reports false.
	if !h.ensureResponsesDependencies(c, reqLog) {
		return
	}

	body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
	if err != nil {
		// A body-size-limit error gets a dedicated 413; any other read
		// failure is reported as a generic 400.
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}
	if !gjson.ValidBytes(body) {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}

	// "model" must be present, be a JSON string, and be non-empty.
	modelResult := gjson.GetBytes(body, "model")
	if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return
	}
	reqModel := modelResult.String()
	reqStream := gjson.GetBytes(body, "stream").Bool()

	reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))

	setOpsRequestContext(c, reqModel, reqStream, body)

	if h.errorPassthroughService != nil {
		service.BindErrorPassthroughService(c, h.errorPassthroughService)
	}

	// The second result is deliberately ignored: whatever subscription value is
	// returned is passed on to the billing check and the usage record as-is.
	subscription, _ := middleware2.GetSubscriptionFromContext(c)

	// Everything up to this point is recorded under the auth latency key;
	// routing latency is measured from here until an account slot is held.
	service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
	routingStart := time.Now()

	userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
	if !acquired {
		return
	}
	if userReleaseFunc != nil {
		defer userReleaseFunc()
	}

	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
		status, code, message := billingErrorDetails(err)
		h.handleStreamingAwareError(c, status, code, message, streamStarted)
		return
	}

	sessionHash := h.gatewayService.GenerateSessionHash(c, body)
	promptCacheKey := h.gatewayService.ExtractSessionID(c, body)

	// The group's default mapped model does not depend on the attempt, so it
	// is read once here and reused by every loop iteration.
	groupDefaultModel := ""
	if apiKey.Group != nil {
		groupDefaultModel = apiKey.Group.DefaultMappedModel
	}

	maxAccountSwitches := h.maxAccountSwitches
	switchCount := 0
	// failedAccountIDs is handed to the scheduler on every selection and
	// collects accounts that this request has already failed over from.
	failedAccountIDs := make(map[int64]struct{})
	sameAccountRetryCount := make(map[int64]int)
	var lastFailoverErr *service.UpstreamFailoverError

	// heldAccountRelease is the release func of the account slot currently
	// held by the loop below, or nil when no slot is held. The deferred guard
	// releases it if the handler unwinds (for example through a panic while
	// forwarding) before the normal release point is reached.
	var heldAccountRelease func()
	releaseAccountSlot := func() {
		release := heldAccountRelease
		// Clear before calling so this handler invokes a given release func
		// at most once, even if the release itself panics.
		heldAccountRelease = nil
		if release != nil {
			release()
		}
	}
	defer releaseAccountSlot()

	for {
		// Reset the per-attempt fallback marker; it is set again below only
		// when this attempt falls back to the group's default model.
		c.Set("openai_chat_completions_fallback_model", "")
		reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
		// The scheduler's second result is not used by this handler.
		selection, _, err := h.gatewayService.SelectAccountWithScheduler(
			c.Request.Context(),
			apiKey.GroupID,
			"",
			sessionHash,
			reqModel,
			failedAccountIDs,
			service.OpenAIUpstreamTransportAny,
		)
		if err != nil {
			reqLog.Warn("openai_chat_completions.account_select_failed",
				zap.Error(err),
				zap.Int("excluded_account_count", len(failedAccountIDs)),
			)
			if len(failedAccountIDs) == 0 {
				// No account has been excluded yet, so the selection failed
				// for the requested model itself. Try once more with the
				// group's default mapped model when it is set and different.
				if groupDefaultModel != "" && groupDefaultModel != reqModel {
					reqLog.Info("openai_chat_completions.fallback_to_default_model",
						zap.String("default_mapped_model", groupDefaultModel),
					)
					selection, _, err = h.gatewayService.SelectAccountWithScheduler(
						c.Request.Context(),
						apiKey.GroupID,
						"",
						sessionHash,
						groupDefaultModel,
						failedAccountIDs,
						service.OpenAIUpstreamTransportAny,
					)
					if err == nil && selection != nil {
						c.Set("openai_chat_completions_fallback_model", groupDefaultModel)
					}
				}
				if err != nil {
					h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
					return
				}
			} else {
				// Accounts were already excluded by earlier failovers and no
				// further account could be selected: report the last failover
				// error if one was recorded, otherwise a generic 502.
				if lastFailoverErr != nil {
					h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
				} else {
					h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
				}
				return
			}
		}
		if selection == nil || selection.Account == nil {
			h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
			return
		}
		account := selection.Account
		sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
		reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
		setOpsSelectedAccount(c, account.ID, account.Platform)

		accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
		if !acquired {
			return
		}
		// From here until releaseAccountSlot runs, the deferred guard owns
		// the release of this slot.
		heldAccountRelease = accountReleaseFunc

		service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
		forwardStart := time.Now()

		// Forward with the group's default mapped model unless this attempt
		// recorded a fallback model in the gin context above.
		defaultMappedModel := groupDefaultModel
		if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
			defaultMappedModel = fallbackModel
		}
		result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)

		forwardDurationMs := time.Since(forwardStart).Milliseconds()
		// Normal-path release: the slot is given back as soon as forwarding
		// returns, before any retry wait or account switch.
		releaseAccountSlot()

		// Response latency is the forward duration minus the upstream latency
		// stored in the context, when that value is positive and smaller than
		// the forward duration; otherwise the whole forward duration is used.
		upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
		responseLatencyMs := forwardDurationMs
		if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
			responseLatencyMs = forwardDurationMs - upstreamLatencyMs
		}
		service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
		if err == nil && result != nil && result.FirstTokenMs != nil {
			service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
		}
		if err != nil {
			var failoverErr *service.UpstreamFailoverError
			if errors.As(err, &failoverErr) {
				h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
				// Pool mode: retry on the same account
				if failoverErr.RetryableOnSameAccount {
					retryLimit := account.GetPoolModeRetryCount()
					if sameAccountRetryCount[account.ID] < retryLimit {
						sameAccountRetryCount[account.ID]++
						reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
							zap.Int64("account_id", account.ID),
							zap.Int("upstream_status", failoverErr.StatusCode),
							zap.Int("retry_limit", retryLimit),
							zap.Int("retry_count", sameAccountRetryCount[account.ID]),
						)
						// Wait before retrying, but give up without writing a
						// response if the request context ends first.
						select {
						case <-c.Request.Context().Done():
							return
						case <-time.After(sameAccountRetryDelay):
						}
						continue
					}
				}
				// Same-account retries are not applicable or are used up:
				// exclude this account and switch, within the switch budget.
				h.gatewayService.RecordOpenAIAccountSwitch()
				failedAccountIDs[account.ID] = struct{}{}
				lastFailoverErr = failoverErr
				if switchCount >= maxAccountSwitches {
					h.handleFailoverExhausted(c, failoverErr, streamStarted)
					return
				}
				switchCount++
				reqLog.Warn("openai_chat_completions.upstream_failover_switching",
					zap.Int64("account_id", account.ID),
					zap.Int("upstream_status", failoverErr.StatusCode),
					zap.Int("switch_count", switchCount),
					zap.Int("max_switches", maxAccountSwitches),
				)
				continue
			}
			// Not a failover error: report the failure, make sure some error
			// response has been written, and stop.
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
			wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
			reqLog.Warn("openai_chat_completions.forward_failed",
				zap.Int64("account_id", account.ID),
				zap.Bool("fallback_error_response_written", wroteFallback),
				zap.Error(err),
			)
			return
		}
		if result != nil {
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
		} else {
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
		}

		// Read request-derived values now so the usage task does not touch
		// the gin context when it runs.
		userAgent := c.GetHeader("User-Agent")
		clientIP := ip.GetClientIP(c)

		h.submitUsageRecordTask(func(ctx context.Context) {
			if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
				Result:        result,
				APIKey:        apiKey,
				User:          apiKey.User,
				Account:       account,
				Subscription:  subscription,
				UserAgent:     userAgent,
				IPAddress:     clientIP,
				APIKeyService: h.apiKeyService,
			}); err != nil {
				// The task uses the global logger with explicit fields rather
				// than the request-scoped reqLog.
				logger.L().With(
					zap.String("component", "handler.openai_gateway.chat_completions"),
					zap.Int64("user_id", subject.UserID),
					zap.Int64("api_key_id", apiKey.ID),
					zap.Any("group_id", apiKey.GroupID),
					zap.String("model", reqModel),
					zap.Int64("account_id", account.ID),
				).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
			}
		})
		reqLog.Debug("openai_chat_completions.request_completed",
			zap.Int64("account_id", account.ID),
			zap.Int("switch_count", switchCount),
		)
		return
	}
}