Files
sub2api/backend/internal/handler/openai_gateway_handler.go

1074 lines
37 KiB
Go
Raw Normal View History

2025-12-22 22:58:31 +08:00
package handler
import (
"context"
"encoding/json"
2025-12-27 11:44:00 +08:00
"errors"
2025-12-22 22:58:31 +08:00
"fmt"
"net/http"
"runtime/debug"
"strconv"
2026-01-10 21:57:57 +08:00
"strings"
2025-12-22 22:58:31 +08:00
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
2025-12-26 10:42:08 +08:00
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
2025-12-24 21:07:21 +08:00
"github.com/Wei-Shaw/sub2api/internal/service"
2025-12-22 22:58:31 +08:00
coderws "github.com/coder/websocket"
2025-12-22 22:58:31 +08:00
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
2025-12-22 22:58:31 +08:00
)
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
2025-12-22 22:58:31 +08:00
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
2025-12-22 22:58:31 +08:00
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
}
2025-12-22 22:58:31 +08:00
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
2025-12-22 22:58:31 +08:00
}
}
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
setOpenAIClientTransportHTTP(c)
requestStart := time.Now()
2025-12-22 22:58:31 +08:00
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
2025-12-22 22:58:31 +08:00
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
2025-12-22 22:58:31 +08:00
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
2025-12-22 22:58:31 +08:00
// Read request body
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
2025-12-22 22:58:31 +08:00
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
2025-12-22 22:58:31 +08:00
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
// 校验请求体 JSON 合法性
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
2025-12-22 22:58:31 +08:00
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
streamResult := gjson.GetBytes(body, "stream")
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
return
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
if previousResponseID != "" {
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
reqLog = reqLog.With(
zap.Bool("has_previous_response_id", true),
zap.String("previous_response_id_kind", previousResponseIDKind),
zap.Int("previous_response_id_len", len(previousResponseID)),
)
if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "previous_response_id_looks_like_message_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
}
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return
}
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
2025-12-22 22:58:31 +08:00
// Get subscription info (may be nil)
2025-12-26 10:42:08 +08:00
subscription, _ := middleware2.GetSubscriptionFromContext(c)
2025-12-22 22:58:31 +08:00
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
2025-12-22 22:58:31 +08:00
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
if !acquired {
2025-12-22 22:58:31 +08:00
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
2025-12-22 22:58:31 +08:00
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
2025-12-22 22:58:31 +08:00
return
}
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
2025-12-22 22:58:31 +08:00
maxAccountSwitches := h.maxAccountSwitches
2025-12-27 11:44:00 +08:00
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
2025-12-22 22:58:31 +08:00
2025-12-27 11:44:00 +08:00
for {
// Select account supporting the requested model
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
2025-12-27 11:44:00 +08:00
if err != nil {
reqLog.Warn("openai.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
2025-12-27 11:44:00 +08:00
if len(failedAccountIDs) == 0 {
fix(audit): 第二批审计修复 — P0 生产 Bug、安全加固、性能优化、缓存一致性、代码质量 基于 backend-code-audit 审计报告,修复剩余 P0/P1/P2 共 34 项问题: P0 生产 Bug: - 修复 time.Since(time.Now()) 计时逻辑错误 (P0-03) - generateRandomID 改用 crypto/rand 替代固定索引 (P0-04) - IncrementQuotaUsed 重写为 Ent 原子操作消除 TOCTOU 竞态 (P0-05) 安全加固: - gateway/openai handler 错误响应替换为泛化消息,防止内部信息泄露 (P1-14) - usage_log_repo dateFormat 参数改用白名单映射,防止 SQL 注入 (P1-16) - 默认配置安全加固:sslmode=prefer、response_headers=true、mode=release (P1-18/19, P2-15) 性能优化: - gateway handler 循环内 defer 替换为显式 releaseWait 闭包 (P1-02) - group_repo/promo_code_repo Count 前 Clone 查询避免状态污染 (P1-03) - usage_log_repo 四个查询添加 LIMIT 10000 防止 OOM (P1-07) - GetBatchUsageStats 添加时间范围参数,默认最近 30 天 (P1-10) - ip.go CIDR 预编译为包级变量 (P1-11) - BatchUpdateCredentials 重构为先验证后更新 (P1-13) 缓存一致性: - billing_cache 添加 jitteredTTL 防止缓存雪崩 (P2-10) - DeductUserBalance/UpdateSubscriptionUsage 错误传播修复 (P2-12) - UserService.UpdateBalance 成功后异步失效 billingCache (P2-13) 代码质量: - search 截断改为按 rune 处理,支持多字节字符 (P2-01) - TLS Handshake 改为 HandshakeContext 支持 context 取消 (P2-07) - CORS 预检添加 Access-Control-Max-Age: 86400 (P2-16) 测试覆盖: - 新增 user_service_test.go(UpdateBalance 缓存失效 6 个用例) - 新增 batch_update_credentials_test.go(fail-fast + 类型验证 7 个用例) - 新增 response_transformer_test.go、ip_test.go、usage_log_repo_unit_test.go、search_truncate_test.go - 集成测试:IncrementQuotaUsed 并发测试、billing_cache 错误传播测试 - config_test.go 补充 server.mode/sslmode 默认值断言 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 19:46:42 +08:00
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
2025-12-27 11:44:00 +08:00
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
2025-12-27 11:44:00 +08:00
return
}
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
if previousResponseID != "" && selection != nil && selection.Account != nil {
reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
}
reqLog.Debug("openai.account_schedule_decision",
zap.String("layer", scheduleDecision.Layer),
zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
zap.Int("top_k", scheduleDecision.TopK),
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
zap.Float64("load_skew", scheduleDecision.LoadSkew),
)
account := selection.Account
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID, account.Platform)
2025-12-22 22:58:31 +08:00
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
if !acquired {
return
2025-12-27 11:44:00 +08:00
}
2025-12-22 22:58:31 +08:00
2025-12-27 11:44:00 +08:00
// Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
2025-12-27 11:44:00 +08:00
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
2025-12-27 11:44:00 +08:00
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
2025-12-27 11:44:00 +08:00
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
h.gatewayService.RecordOpenAIAccountSwitch()
2025-12-27 11:44:00 +08:00
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
2025-12-27 11:44:00 +08:00
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
2025-12-27 11:44:00 +08:00
return
}
switchCount++
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
2025-12-27 11:44:00 +08:00
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.forward_failed", fields...)
return
}
reqLog.Error("openai.forward_failed", fields...)
2025-12-27 11:44:00 +08:00
return
2025-12-22 22:58:31 +08:00
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
2025-12-27 11:44:00 +08:00
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
2025-12-27 11:44:00 +08:00
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
2025-12-27 11:44:00 +08:00
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai.record_usage_failed", zap.Error(err))
2025-12-27 11:44:00 +08:00
}
})
reqLog.Debug("openai.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
2025-12-27 11:44:00 +08:00
return
}
2025-12-22 22:58:31 +08:00
}
// validateFunctionCallOutputRequest pre-checks requests that carry
// function_call_output input items. It returns false (after writing a 400) when
// such an item cannot be linked to any context: no previous_response_id, no
// tool-call context, and either a missing call_id or missing item references.
// When the body does get parsed, the resulting map is stored on the gin context
// under service.OpenAIParsedRequestBodyKey (presumably so downstream code can
// reuse it — confirm).
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
	// Cheap gjson probe first; only pay for a full decode when needed.
	if hasOutputItem := gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists(); !hasOutputItem {
		return true
	}

	var parsed map[string]any
	if json.Unmarshal(body, &parsed) != nil {
		// Lenient by design: if decoding fails, skip the pre-check and let upstream decide.
		return true
	}
	c.Set(service.OpenAIParsedRequestBodyKey, parsed)

	check := service.ValidateFunctionCallOutputContext(parsed)
	prevID, _ := parsed["previous_response_id"].(string)

	switch {
	case !check.HasFunctionCallOutput,
		strings.TrimSpace(prevID) != "",
		check.HasToolCallContext:
		return true
	case check.HasFunctionCallOutputMissingCallID:
		reqLog.Warn("openai.request_validation_failed",
			zap.String("reason", "function_call_output_missing_call_id"),
		)
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
		return false
	case check.HasItemReferenceForAllCallIDs:
		return true
	}

	reqLog.Warn("openai.request_validation_failed",
		zap.String("reason", "function_call_output_missing_item_reference"),
	)
	h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
	return false
}
// acquireResponsesUserSlot obtains a per-user concurrency slot for an HTTP
// Responses request. It returns (release, true) on success, where release is
// wrapped by wrapReleaseOnDone; on failure it writes the error response itself
// and returns (nil, false).
//
// It tries a non-blocking acquire first. If that fails, it joins the user's wait
// queue (bounded by service.CalculateMaxWait) and blocks in
// AcquireUserSlotWithWait. A failing wait counter is tolerated: the request
// still proceeds to the blocking acquire.
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
	c *gin.Context,
	userID int64,
	userConcurrency int,
	reqStream bool,
	streamStarted *bool,
	reqLog *zap.Logger,
) (func(), bool) {
	ctx := c.Request.Context()

	release, gotSlot, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
	switch {
	case err != nil:
		reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", *streamStarted)
		return nil, false
	case gotSlot:
		return wrapReleaseOnDone(ctx, release), true
	}

	maxWait := service.CalculateMaxWait(userConcurrency)
	canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
	switch {
	case waitErr != nil:
		reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
		// Degrade gracefully: a broken wait counter does not block the acquire below.
	case !canWait:
		reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
		h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
		return nil, false
	}

	// Past the switch, waitErr == nil implies canWait, i.e. we were counted.
	inQueue := waitErr == nil
	leaveQueue := func() {
		if inQueue {
			h.concurrencyHelper.DecrementWaitCount(ctx, userID)
			inQueue = false
		}
	}
	defer leaveQueue()

	release, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
	if err != nil {
		reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", *streamStarted)
		return nil, false
	}
	// Holding a slot now: stop counting as a waiter immediately.
	leaveQueue()
	return wrapReleaseOnDone(ctx, release), true
}
// acquireResponsesAccountSlot obtains a concurrency slot on the selected
// upstream account. It returns (release, true) on success, where release is
// wrapped by wrapReleaseOnDone; on failure it writes the error response itself
// and returns (nil, false).
//
// Order of attempts:
//  1. a slot already reserved by the scheduler (selection.Acquired);
//  2. a non-blocking TryAcquireAccountSlot;
//  3. joining the account's wait queue and blocking with the plan's timeout.
//
// Whenever the slot is obtained via 2 or 3, the session is bound sticky to the
// account; a bind failure is only logged.
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
	c *gin.Context,
	groupID *int64,
	sessionHash string,
	selection *service.AccountSelectionResult,
	reqStream bool,
	streamStarted *bool,
	reqLog *zap.Logger,
) (func(), bool) {
	noAccount := func() (func(), bool) {
		h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
		return nil, false
	}
	if selection == nil || selection.Account == nil {
		return noAccount()
	}

	ctx := c.Request.Context()
	acct := selection.Account
	if selection.Acquired {
		return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
	}
	plan := selection.WaitPlan
	if plan == nil {
		return noAccount()
	}

	bindSticky := func() {
		if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, acct.ID); err != nil {
			reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", acct.ID), zap.Error(err))
		}
	}

	quickRelease, quickOK, err := h.concurrencyHelper.TryAcquireAccountSlot(
		ctx,
		acct.ID,
		plan.MaxConcurrency,
	)
	if err != nil {
		reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", acct.ID), zap.Error(err))
		h.handleConcurrencyError(c, err, "account", *streamStarted)
		return nil, false
	}
	if quickOK {
		bindSticky()
		return wrapReleaseOnDone(ctx, quickRelease), true
	}

	canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, acct.ID, plan.MaxWaiting)
	switch {
	case waitErr != nil:
		reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", acct.ID), zap.Error(waitErr))
	case !canWait:
		reqLog.Info("openai.account_wait_queue_full",
			zap.Int64("account_id", acct.ID),
			zap.Int("max_waiting", plan.MaxWaiting),
		)
		h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
		return nil, false
	}

	// Past the switch, waitErr == nil implies canWait, i.e. we were counted.
	queued := waitErr == nil
	leaveQueue := func() {
		if queued {
			h.concurrencyHelper.DecrementAccountWaitCount(ctx, acct.ID)
			queued = false
		}
	}
	defer leaveQueue()

	release, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
		c,
		acct.ID,
		plan.MaxConcurrency,
		plan.Timeout,
		reqStream,
		streamStarted,
	)
	if err != nil {
		reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", acct.ID), zap.Error(err))
		h.handleConcurrencyError(c, err, "account", *streamStarted)
		return nil, false
	}
	// Slot held: we are no longer a waiter.
	leaveQueue()
	bindSticky()
	return wrapReleaseOnDone(ctx, release), true
}
// ResponsesWebSocket handles the OpenAI Responses API WebSocket ingress endpoint.
// GET /openai/v1/responses (Upgrade: websocket)
//
// The first client message must be a JSON payload carrying a model (the
// response.create message); it drives billing checks and account selection.
// User and account concurrency slots are held per turn: the first turn uses
// slots acquired here, and later turns re-acquire them in BeforeTurn. Slots are
// released in AfterTurn, so an idle connection does not occupy them. Failures
// are reported to the client as WebSocket close frames rather than HTTP errors.
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
	const (
		// wsMaxMessageBytes caps the size of a single client WebSocket message.
		wsMaxMessageBytes = 16 * 1024 * 1024
		// wsFirstMessageTimeout bounds the wait for the initial response.create
		// message. It is used both for the read deadline and in the failure log.
		wsFirstMessageTimeout = 30 * time.Second
	)

	if !isOpenAIWSUpgradeRequest(c.Request) {
		h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
		return
	}
	setOpenAIClientTransportWS(c)
	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}
	reqLog := requestLogger(
		c,
		"handler.openai_gateway.responses_ws",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
		zap.Bool("openai_ws_mode", true),
	)
	if !h.ensureResponsesDependencies(c, reqLog) {
		return
	}
	reqLog.Info("openai.websocket_ingress_started")
	clientIP := ip.GetClientIP(c)
	userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))

	wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
		CompressionMode: coderws.CompressionContextTakeover,
	})
	if err != nil {
		reqLog.Warn("openai.websocket_accept_failed",
			zap.Error(err),
			zap.String("client_ip", clientIP),
			zap.String("request_user_agent", userAgent),
			zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
			zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
			zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
			zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
		)
		return
	}
	defer func() {
		_ = wsConn.CloseNow()
	}()
	wsConn.SetReadLimit(wsMaxMessageBytes)

	ctx := c.Request.Context()
	readCtx, cancel := context.WithTimeout(ctx, wsFirstMessageTimeout)
	msgType, firstMessage, err := wsConn.Read(readCtx)
	cancel()
	if err != nil {
		closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
		reqLog.Warn("openai.websocket_read_first_message_failed",
			zap.Error(err),
			zap.String("client_ip", clientIP),
			zap.String("close_status", closeStatus),
			zap.String("close_reason", closeReason),
			zap.Duration("read_timeout", wsFirstMessageTimeout),
		)
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
		return
	}
	if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
		return
	}
	if !gjson.ValidBytes(firstMessage) {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
		return
	}
	reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
	if reqModel == "" {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
		return
	}
	previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
	previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
	if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
		return
	}
	reqLog = reqLog.With(
		zap.Bool("ws_ingress", true),
		zap.String("model", reqModel),
		zap.Bool("has_previous_response_id", previousResponseID != ""),
		zap.String("previous_response_id_kind", previousResponseIDKind),
	)
	setOpsRequestContext(c, reqModel, true, firstMessage)

	// Release funcs for the slots held by the current turn; nil when none are held.
	var currentUserRelease func()
	var currentAccountRelease func()
	releaseTurnSlots := func() {
		if currentAccountRelease != nil {
			currentAccountRelease()
			currentAccountRelease = nil
		}
		if currentUserRelease != nil {
			currentUserRelease()
			currentUserRelease = nil
		}
	}
	// Register early so every early return releases any slots already acquired.
	defer releaseTurnSlots()

	userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
	if err != nil {
		reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
		return
	}
	if !userAcquired {
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
		return
	}
	currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)

	subscription, _ := middleware2.GetSubscriptionFromContext(c)
	if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
		return
	}
	sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
		c,
		firstMessage,
		openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
	)
	selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
		ctx,
		apiKey.GroupID,
		previousResponseID,
		sessionHash,
		reqModel,
		nil,
		service.OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	if err != nil {
		reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
		return
	}
	if selection == nil || selection.Account == nil {
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
		return
	}
	account := selection.Account
	// Per-turn re-acquisition uses the scheduler's wait-plan limit when present,
	// otherwise the account's own concurrency setting.
	accountMaxConcurrency := account.Concurrency
	if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
		accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
	}
	accountReleaseFunc := selection.ReleaseFunc
	if !selection.Acquired {
		if selection.WaitPlan == nil {
			closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
			return
		}
		// WebSocket ingress never blocks waiting for an account slot; it tries once.
		fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
			ctx,
			account.ID,
			selection.WaitPlan.MaxConcurrency,
		)
		if err != nil {
			reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
			closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
			return
		}
		if !fastAcquired {
			closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
			return
		}
		accountReleaseFunc = fastReleaseFunc
	}
	currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
	if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
		reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
	}

	token, _, err := h.gatewayService.GetAccessToken(ctx, account)
	if err != nil {
		reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
		return
	}
	reqLog.Debug("openai.websocket_account_selected",
		zap.Int64("account_id", account.ID),
		zap.String("account_name", account.Name),
		zap.String("schedule_layer", scheduleDecision.Layer),
		zap.Int("candidate_count", scheduleDecision.CandidateCount),
	)

	hooks := &service.OpenAIWSIngressHooks{
		BeforeTurn: func(turn int) error {
			// Turn 1 uses the slots acquired above.
			if turn == 1 {
				return nil
			}
			// Defensive cleanup: never overwrite (and so leak) slots left from an abnormal path.
			releaseTurnSlots()
			// Later turns must re-acquire slots so an idle long-lived connection does not hold them.
			userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
			if err != nil {
				return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
			}
			if !userAcquired {
				return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
			}
			accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
			if err != nil {
				if userReleaseFunc != nil {
					userReleaseFunc()
				}
				return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
			}
			if !accountAcquired {
				if userReleaseFunc != nil {
					userReleaseFunc()
				}
				return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
			}
			currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
			currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
			return nil
		},
		AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
			releaseTurnSlots()
			if turnErr != nil || result == nil {
				return
			}
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
			h.submitUsageRecordTask(func(taskCtx context.Context) {
				if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
					Result:        result,
					APIKey:        apiKey,
					User:          apiKey.User,
					Account:       account,
					Subscription:  subscription,
					UserAgent:     userAgent,
					IPAddress:     clientIP,
					APIKeyService: h.apiKeyService,
				}); err != nil {
					reqLog.Error("openai.websocket_record_usage_failed",
						zap.Int64("account_id", account.ID),
						zap.String("request_id", result.RequestID),
						zap.Error(err),
					)
				}
			})
		},
	}

	if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
		h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
		closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
		reqLog.Warn("openai.websocket_proxy_failed",
			zap.Int64("account_id", account.ID),
			zap.Error(err),
			zap.String("close_status", closeStatus),
			zap.String("close_reason", closeReason),
		)
		// Errors that carry an explicit client close code/reason are relayed as-is.
		var closeErr *service.OpenAIWSClientCloseError
		if errors.As(err, &closeErr) {
			closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
			return
		}
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
		return
	}
	reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
}
// recoverResponsesPanic must be invoked directly via defer, because recover()
// only stops a panic when called from the deferred function itself. If a panic
// is in flight, it writes a fallback error response when possible (taking into
// account whether streaming already began) and logs the panic value with a
// stack trace. If there is no panic, it does nothing.
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
	panicValue := recover()
	if panicValue == nil {
		return
	}

	// A nil pointer means the stream state is unknown; treat it as not started.
	started := streamStarted != nil && *streamStarted
	wroteFallback := h.ensureForwardErrorResponse(c, started)

	panicLog := requestLogger(c, "handler.openai_gateway.responses")
	panicLog.Error(
		"openai.responses_panic_recovered",
		zap.Bool("fallback_error_response_written", wroteFallback),
		zap.Any("panic", panicValue),
		zap.ByteString("stack", debug.Stack()),
	)
}
// ensureResponsesDependencies reports whether every dependency the Responses
// handlers require is wired. If any is missing, it logs the missing names and,
// if nothing has been written yet, replies 503 with a generic api_error body.
// A nil reqLog is replaced with a fresh request logger.
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
	missing := h.missingResponsesDependencies()
	if len(missing) == 0 {
		return true
	}

	if reqLog == nil {
		reqLog = requestLogger(c, "handler.openai_gateway.responses")
	}
	reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))

	// Never overwrite a response that has already started.
	canWrite := c != nil && c.Writer != nil && !c.Writer.Written()
	if canWrite {
		payload := gin.H{
			"error": gin.H{
				"type":    "api_error",
				"message": "Service temporarily unavailable",
			},
		}
		c.JSON(http.StatusServiceUnavailable, payload)
	}
	return false
}
// missingResponsesDependencies lists the names of required dependencies that
// are nil. A nil receiver yields just "handler". The result is empty (non-nil)
// when everything is wired. usageRecordWorkerPool and errorPassthroughService
// are optional and are not checked.
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
	missing := make([]string, 0, 5)
	if h == nil {
		return append(missing, "handler")
	}

	checks := []struct {
		name   string
		absent bool
	}{
		{"gatewayService", h.gatewayService == nil},
		{"billingCacheService", h.billingCacheService == nil},
		{"apiKeyService", h.apiKeyService == nil},
		{"concurrencyHelper", h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil},
	}
	for _, dep := range checks {
		if dep.absent {
			missing = append(missing, dep.name)
		}
	}
	return missing
}
// getContextInt64 reads key from the gin context and converts it to int64.
//
// Stored values of type int64, int, int32 and float64 are accepted. A float64
// is converted with Go's int64() conversion, which truncates toward zero. The
// second result is false when c is nil, key is empty, the key is absent, or
// the stored value has any other type.
func getContextInt64(c *gin.Context, key string) (int64, bool) {
	if c == nil || key == "" {
		return 0, false
	}
	raw, found := c.Get(key)
	if !found {
		return 0, false
	}

	var out int64
	switch n := raw.(type) {
	case int64:
		out = n
	case int:
		out = int64(n)
	case int32:
		out = int64(n)
	case float64:
		out = int64(n)
	default:
		return 0, false
	}
	return out, true
}
// submitUsageRecordTask hands a usage-record task to the worker pool.
//
// A nil task is ignored. Without a pool, the task runs synchronously on the
// calling goroutine with a 10-second context derived from
// context.Background(). A panic inside the task is recovered and logged
// rather than propagated.
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}
	if pool := h.usageRecordWorkerPool; pool != nil {
		pool.Submit(task)
		return
	}

	// Fallback path: with no worker pool injected, run the task synchronously
	// instead of reverting to unbounded goroutine spawning.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		logger.L().With(
			zap.String("component", "handler.openai_gateway.responses"),
			zap.Any("panic", r),
		).Error("openai.usage_record_task_panic_recovered")
	}()
	task(ctx)
}
2025-12-22 22:58:31 +08:00
// handleConcurrencyError reports that a concurrency slot of kind slotType could
// not be acquired. The client gets a 429 rate_limit_error, or an SSE error
// event if streaming has already started. err is currently not inspected.
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
	msg := "Concurrency limit exceeded for " + slotType + ", please retry later"
	h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", msg, streamStarted)
}
// handleFailoverExhausted reports the last upstream failure once no more
// accounts can be tried.
//
// If an error-passthrough rule matches the upstream status and body, the rule
// decides the client status (upstream status unless overridden) and message
// (extracted upstream message unless overridden). It may also mark the request
// to be skipped by ops monitoring. Otherwise the status is translated by
// mapUpstreamError. failoverErr must be non-nil.
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
	upstreamStatus, upstreamBody := failoverErr.StatusCode, failoverErr.ResponseBody

	// Passthrough rules take precedence, but can only be matched against a body.
	if h.errorPassthroughService != nil && len(upstreamBody) > 0 {
		if rule := h.errorPassthroughService.MatchRule("openai", upstreamStatus, upstreamBody); rule != nil {
			clientStatus := upstreamStatus
			if rule.ResponseCode != nil && !rule.PassthroughCode {
				clientStatus = *rule.ResponseCode
			}

			clientMsg := service.ExtractUpstreamErrorMessage(upstreamBody)
			if rule.CustomMessage != nil && !rule.PassthroughBody {
				clientMsg = *rule.CustomMessage
			}

			if rule.SkipMonitoring {
				c.Set(service.OpsSkipPassthroughKey, true)
			}
			h.handleStreamingAwareError(c, clientStatus, "upstream_error", clientMsg, streamStarted)
			return
		}
	}

	// No matching rule: fall back to the default status mapping.
	mappedStatus, errType, errMsg := h.mapUpstreamError(upstreamStatus)
	h.handleStreamingAwareError(c, mappedStatus, errType, errMsg, streamStarted)
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
2025-12-27 11:44:00 +08:00
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// mapUpstreamError translates an upstream HTTP status into the
// (client status, error type, client message) triple returned to callers.
//
// Upstream 429 is passed through as a rate_limit_error. 529 becomes 503.
// Auth failures and 5xx errors are reported as 502, as is anything unrecognised.
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
	const upstreamErr = "upstream_error"

	switch statusCode {
	case http.StatusUnauthorized:
		return http.StatusBadGateway, upstreamErr, "Upstream authentication failed, please contact administrator"
	case http.StatusForbidden:
		return http.StatusBadGateway, upstreamErr, "Upstream access forbidden, please contact administrator"
	case http.StatusTooManyRequests:
		return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
	case 529: // non-standard status, treated here as "overloaded"
		return http.StatusServiceUnavailable, upstreamErr, "Upstream service overloaded, please retry later"
	case http.StatusInternalServerError, http.StatusBadGateway,
		http.StatusServiceUnavailable, http.StatusGatewayTimeout:
		return http.StatusBadGateway, upstreamErr, "Upstream service temporarily unavailable"
	}
	return http.StatusBadGateway, upstreamErr, "Upstream request failed"
}
2025-12-22 22:58:31 +08:00
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// SSE 错误事件固定 schema使用 Quote 直拼可避免额外 Marshal 分配。
errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
2025-12-22 22:58:31 +08:00
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
// Normal case: return JSON response with proper status code
h.errorResponse(c, status, errType, message)
}
// ensureForwardErrorResponse writes a generic 502 upstream_error response when
// Forward has failed but nothing has been written to the client yet. The
// response is JSON, or an SSE error event if streamStarted is true.
//
// It returns true only if it wrote that fallback response.
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
	alreadyAnswered := c == nil || c.Writer == nil || c.Writer.Written()
	if alreadyAnswered {
		return false
	}
	h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
	return true
}
// shouldLogOpenAIForwardFailureAsWarn reports whether a forward failure should
// be logged at warn rather than error level.
//
// That is the case only when this handler did not write a fallback response
// and the response writer has nonetheless already written output (presumably
// from the forward attempt itself — confirm against callers).
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
	return !wroteFallback && c != nil && c.Writer != nil && c.Writer.Written()
}
2025-12-22 22:58:31 +08:00
// errorResponse writes an OpenAI-style error body,
// {"error":{"type":errType,"message":message}}, with the given HTTP status.
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
	detail := gin.H{
		"type":    errType,
		"message": message,
	}
	c.JSON(status, gin.H{"error": detail})
}
// setOpenAIClientTransportHTTP marks the request's client transport as
// service.OpenAIClientTransportHTTP via service.SetOpenAIClientTransport.
func setOpenAIClientTransportHTTP(c *gin.Context) {
	service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
}
// setOpenAIClientTransportWS marks the request's client transport as
// service.OpenAIClientTransportWS via service.SetOpenAIClientTransport.
func setOpenAIClientTransportWS(c *gin.Context) {
	service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
}
// openAIWSIngressFallbackSessionSeed builds a deterministic seed string of the
// form "openai_ws_ingress:<groupID>:<userID>:<apiKeyID>". A nil groupID is
// rendered as 0. Presumably it is used when the client supplies no session
// identifier of its own — confirm against callers.
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
	var groupPart int64
	if groupID != nil {
		groupPart = *groupID
	}
	return "openai_ws_ingress:" +
		strconv.FormatInt(groupPart, 10) + ":" +
		strconv.FormatInt(userID, 10) + ":" +
		strconv.FormatInt(apiKeyID, 10)
}
// isOpenAIWSUpgradeRequest reports whether r asks for a WebSocket upgrade.
//
// Per RFC 6455 §4.2.1 (and the RFC 7230 list syntax it builds on), both the
// Upgrade and Connection headers are comma-separated token lists. Either may
// also be split across several header lines. A request qualifies when
// Upgrade contains the token "websocket" and Connection contains the token
// "upgrade", compared case-insensitively. Exact token matching avoids false
// positives such as "Connection: x-upgrade-hint" that a substring check would
// accept.
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
	if r == nil {
		return false
	}
	return openAIWSHeaderHasToken(r.Header, "Upgrade", "websocket") &&
		openAIWSHeaderHasToken(r.Header, "Connection", "upgrade")
}

// openAIWSHeaderHasToken reports whether any comma-separated element across
// all values of header key equals token, ignoring case and surrounding space.
func openAIWSHeaderHasToken(h http.Header, key, token string) bool {
	for _, value := range h.Values(key) {
		for _, part := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(part), token) {
				return true
			}
		}
	}
	return false
}
// closeOpenAIClientWS performs a best-effort close of a client WebSocket with
// the given status and reason. A nil conn is a no-op.
//
// The reason is trimmed and capped at 120 bytes, keeping it under the 123-byte
// limit RFC 6455 leaves for a close reason (125-byte control payload minus the
// 2-byte status code). Because the cut is at a byte offset, and reasons may
// carry arbitrary text, invalid UTF-8 is stripped afterwards. RFC 6455 requires
// the close reason to be UTF-8, and a split multi-byte rune would otherwise
// violate that.
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
	if conn == nil {
		return
	}
	reason = strings.TrimSpace(reason)
	if len(reason) > 120 {
		reason = reason[:120]
	}
	reason = strings.ToValidUTF8(reason, "")

	// Errors are ignored deliberately: the connection is being torn down either
	// way. CloseNow is called as a backstop after the graceful Close (presumably
	// to force the connection shut if the close handshake did not complete).
	_ = conn.Close(status, reason)
	_ = conn.CloseNow()
}
// summarizeWSCloseErrorForLog extracts log-friendly close details from err.
//
// It returns (status, reason), where status is formatted as "<code>(<name>)"
// and reason is the trimmed close reason. "-" stands in for any value that is
// unavailable: both are "-" when err is nil or carries no WebSocket close
// status, and reason is "-" when the close reason is blank.
func summarizeWSCloseErrorForLog(err error) (string, string) {
	const placeholder = "-"
	if err == nil {
		return placeholder, placeholder
	}

	code := coderws.CloseStatus(err)
	if code == -1 {
		return placeholder, placeholder
	}
	status := fmt.Sprintf("%d(%s)", int(code), code.String())

	reason := placeholder
	var ce coderws.CloseError
	if errors.As(err, &ce) {
		if trimmed := strings.TrimSpace(ce.Reason); trimmed != "" {
			reason = trimmed
		}
	}
	return status, reason
}