2025-12-22 22:58:31 +08:00
package handler
import (
"context"
"encoding/json"
2025-12-27 11:44:00 +08:00
"errors"
2025-12-22 22:58:31 +08:00
"fmt"
"io"
"log"
"net/http"
2026-01-10 21:57:57 +08:00
"strings"
2025-12-22 22:58:31 +08:00
"time"
2026-01-04 19:49:59 +08:00
"github.com/Wei-Shaw/sub2api/internal/config"
2026-01-12 20:44:38 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
2025-12-24 21:07:21 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
2025-12-26 10:42:08 +08:00
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
2025-12-24 21:07:21 +08:00
"github.com/Wei-Shaw/sub2api/internal/service"
2025-12-22 22:58:31 +08:00
"github.com/gin-gonic/gin"
)
// OpenAIGatewayHandler handles OpenAI API gateway requests (the Responses
// endpoint), coordinating account selection, user/account concurrency
// limits, billing eligibility checks and asynchronous usage recording.
type OpenAIGatewayHandler struct {
	gatewayService          *service.OpenAIGatewayService    // account selection, sticky sessions, upstream forwarding, usage recording
	billingCacheService     *service.BillingCacheService     // billing eligibility re-check after queue wait
	apiKeyService           *service.APIKeyService           // passed through to RecordUsage
	errorPassthroughService *service.ErrorPassthroughService // optional upstream-error passthrough rules (may be nil)
	concurrencyHelper       *ConcurrencyHelper               // user/account slots and wait-queue counters
	maxAccountSwitches      int                              // max failover retries across accounts per request
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler (
gatewayService * service . OpenAIGatewayService ,
concurrencyService * service . ConcurrencyService ,
billingCacheService * service . BillingCacheService ,
2026-02-03 19:01:49 +08:00
apiKeyService * service . APIKeyService ,
2026-02-05 21:52:54 +08:00
errorPassthroughService * service . ErrorPassthroughService ,
2026-01-04 19:49:59 +08:00
cfg * config . Config ,
2025-12-22 22:58:31 +08:00
) * OpenAIGatewayHandler {
2026-01-04 19:49:59 +08:00
pingInterval := time . Duration ( 0 )
2026-01-16 20:18:30 +08:00
maxAccountSwitches := 3
2026-01-04 19:49:59 +08:00
if cfg != nil {
pingInterval = time . Duration ( cfg . Concurrency . PingInterval ) * time . Second
2026-01-16 20:18:30 +08:00
if cfg . Gateway . MaxAccountSwitches > 0 {
maxAccountSwitches = cfg . Gateway . MaxAccountSwitches
}
2026-01-04 19:49:59 +08:00
}
2025-12-22 22:58:31 +08:00
return & OpenAIGatewayHandler {
2026-02-05 21:52:54 +08:00
gatewayService : gatewayService ,
billingCacheService : billingCacheService ,
apiKeyService : apiKeyService ,
errorPassthroughService : errorPassthroughService ,
concurrencyHelper : NewConcurrencyHelper ( concurrencyService , SSEPingFormatComment , pingInterval ) ,
maxAccountSwitches : maxAccountSwitches ,
2025-12-22 22:58:31 +08:00
}
}
// Responses handles the OpenAI Responses API endpoint.
// POST /openai/v1/responses
//
// Pipeline: authenticate -> read/validate body -> wait-queue + user slot ->
// billing re-check -> (loop) select account, acquire account slot, forward;
// on upstream failover errors the loop switches accounts up to
// h.maxAccountSwitches times. Usage is recorded asynchronously on success.
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
	// Get apiKey and user from context (set by ApiKeyAuth middleware)
	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}

	// Read request body
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		// Distinguish "body too large" (from http.MaxBytesReader) for a 413.
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}
	// Record ops context early (model/stream unknown yet); refreshed below.
	setOpsRequestContext(c, "", false, body)

	// Parse request body to map for potential modification
	var reqBody map[string]any
	if err := json.Unmarshal(body, &reqBody); err != nil {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}

	// Extract model and stream
	reqModel, _ := reqBody["model"].(string)
	reqStream, _ := reqBody["stream"].(bool)

	// model is required
	if reqModel == "" {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return
	}

	// For non-Codex-CLI clients, inject default instructions when the request
	// carries none, re-serializing the body so the upstream sees the change.
	userAgent := c.GetHeader("User-Agent")
	if !openai.IsCodexCLIRequest(userAgent) {
		existingInstructions, _ := reqBody["instructions"].(string)
		if strings.TrimSpace(existingInstructions) == "" {
			if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
				reqBody["instructions"] = instructions
				// Re-serialize body
				body, err = json.Marshal(reqBody)
				if err != nil {
					h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
					return
				}
			}
		}
	}
	setOpsRequestContext(c, reqModel, reqStream, body)

	// Validate early that any function_call_output has a linkable context, to
	// avoid an upstream 400. Requires previous_response_id, or a
	// tool_call/function_call with call_id in input, or an item_reference
	// whose id matches the call_id.
	if service.HasFunctionCallOutput(reqBody) {
		previousResponseID, _ := reqBody["previous_response_id"].(string)
		if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
			if service.HasFunctionCallOutputMissingCallID(reqBody) {
				log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
				h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
				return
			}
			callIDs := service.FunctionCallOutputCallIDs(reqBody)
			if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
				log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
				h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
				return
			}
		}
	}

	// Track if we've started streaming (for error handling)
	streamStarted := false

	// Get subscription info (may be nil)
	subscription, _ := middleware2.GetSubscriptionFromContext(c)

	// 0. Check if wait queue is full
	maxWait := service.CalculateMaxWait(subject.Concurrency)
	canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
	waitCounted := false
	if err != nil {
		log.Printf("Increment wait count failed: %v", err)
		// On error, allow request to proceed
	} else if !canWait {
		h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
		return
	}
	if err == nil && canWait {
		waitCounted = true
	}
	// Safety net: decrement the wait counter if we leave while still counted.
	defer func() {
		if waitCounted {
			h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
		}
	}()

	// 1. First acquire user concurrency slot
	userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
	if err != nil {
		log.Printf("User concurrency acquire failed: %v", err)
		h.handleConcurrencyError(c, err, "user", streamStarted)
		return
	}
	// User slot acquired: no longer waiting.
	if waitCounted {
		h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
		waitCounted = false
	}
	// Ensure the slot is also released when the request is canceled, so a
	// passively dropped long connection cannot leak it.
	userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
	if userReleaseFunc != nil {
		defer userReleaseFunc()
	}

	// 2. Re-check billing eligibility after wait
	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		log.Printf("Billing eligibility check failed after wait: %v", err)
		status, code, message := billingErrorDetails(err)
		h.handleStreamingAwareError(c, status, code, message, streamStarted)
		return
	}

	// Generate session hash (header first; fallback to prompt_cache_key)
	sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)

	maxAccountSwitches := h.maxAccountSwitches
	switchCount := 0
	failedAccountIDs := make(map[int64]struct{})
	var lastFailoverErr *service.UpstreamFailoverError

	for {
		// Select account supporting the requested model
		log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
		if err != nil {
			log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
			if len(failedAccountIDs) == 0 {
				// Nothing ever worked: plain "no accounts" error.
				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
				return
			}
			// We failed over at least once; surface the last upstream error.
			if lastFailoverErr != nil {
				h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
			} else {
				h.handleFailoverExhaustedSimple(c, 502, streamStarted)
			}
			return
		}
		account := selection.Account
		log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
		setOpsSelectedAccount(c, account.ID)

		// 3. Acquire account concurrency slot
		accountReleaseFunc := selection.ReleaseFunc
		if !selection.Acquired {
			if selection.WaitPlan == nil {
				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
				return
			}
			accountWaitCounted := false
			canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
			if err != nil {
				log.Printf("Increment account wait count failed: %v", err)
			} else if !canWait {
				log.Printf("Account wait queue full: account=%d", account.ID)
				h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
				return
			}
			if err == nil && canWait {
				accountWaitCounted = true
			}
			// NOTE(review): this defer is registered inside the retry loop, so one
			// closure accumulates per failover attempt. It appears harmless because
			// each iteration captures its own accountWaitCounted, which is cleared
			// after the explicit decrement below, and attempts are bounded by
			// maxAccountSwitches — confirm if the loop bound ever grows.
			defer func() {
				if accountWaitCounted {
					h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
				}
			}()
			accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
				c,
				account.ID,
				selection.WaitPlan.MaxConcurrency,
				selection.WaitPlan.Timeout,
				reqStream,
				&streamStarted,
			)
			if err != nil {
				log.Printf("Account concurrency acquire failed: %v", err)
				h.handleConcurrencyError(c, err, "account", streamStarted)
				return
			}
			// Account slot acquired: no longer waiting on this account.
			if accountWaitCounted {
				h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
				accountWaitCounted = false
			}
			if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
				log.Printf("Bind sticky session failed: %v", err)
			}
		}
		// Account slot / wait counters must be reclaimed safely on timeout or
		// client disconnect.
		accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)

		// Forward request
		result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
		if accountReleaseFunc != nil {
			accountReleaseFunc()
		}
		if err != nil {
			var failoverErr *service.UpstreamFailoverError
			if errors.As(err, &failoverErr) {
				// Upstream error eligible for failover: mark this account and
				// retry with another, up to maxAccountSwitches times.
				failedAccountIDs[account.ID] = struct{}{}
				lastFailoverErr = failoverErr
				if switchCount >= maxAccountSwitches {
					h.handleFailoverExhausted(c, failoverErr, streamStarted)
					return
				}
				switchCount++
				log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
				continue
			}
			// Error response already handled in Forward, just log
			log.Printf("Account %d: Forward request failed: %v", account.ID, err)
			return
		}

		// Capture request info here (for async recording; avoids touching
		// gin.Context from the goroutine).
		userAgent := c.GetHeader("User-Agent")
		clientIP := ip.GetClientIP(c)

		// Async record usage
		go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()
			if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
				Result:        result,
				APIKey:        apiKey,
				User:          apiKey.User,
				Account:       usedAccount,
				Subscription:  subscription,
				UserAgent:     ua,
				IPAddress:     ip,
				APIKeyService: h.apiKeyService,
			}); err != nil {
				log.Printf("Record usage failed: %v", err)
			}
		}(result, account, userAgent, clientIP)
		return
	}
}
// handleConcurrencyError reports a concurrency-limit failure for the given
// slot type ("user" or "account") as a 429, stream-aware.
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
	message := fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType)
	h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", message, streamStarted)
}
2026-02-05 21:52:54 +08:00
func ( h * OpenAIGatewayHandler ) handleFailoverExhausted ( c * gin . Context , failoverErr * service . UpstreamFailoverError , streamStarted bool ) {
statusCode := failoverErr . StatusCode
responseBody := failoverErr . ResponseBody
// 先检查透传规则
if h . errorPassthroughService != nil && len ( responseBody ) > 0 {
if rule := h . errorPassthroughService . MatchRule ( "openai" , statusCode , responseBody ) ; rule != nil {
// 确定响应状态码
respCode := statusCode
if ! rule . PassthroughCode && rule . ResponseCode != nil {
respCode = * rule . ResponseCode
}
// 确定响应消息
msg := service . ExtractUpstreamErrorMessage ( responseBody )
if ! rule . PassthroughBody && rule . CustomMessage != nil {
msg = * rule . CustomMessage
}
h . handleStreamingAwareError ( c , respCode , "upstream_error" , msg , streamStarted )
return
}
}
// 使用默认的错误映射
status , errType , errMsg := h . mapUpstreamError ( statusCode )
h . handleStreamingAwareError ( c , status , errType , errMsg , streamStarted )
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func ( h * OpenAIGatewayHandler ) handleFailoverExhaustedSimple ( c * gin . Context , statusCode int , streamStarted bool ) {
2025-12-27 11:44:00 +08:00
status , errType , errMsg := h . mapUpstreamError ( statusCode )
h . handleStreamingAwareError ( c , status , errType , errMsg , streamStarted )
}
// mapUpstreamError translates an upstream HTTP status code into the
// (status, error type, message) triple returned to our client. Auth and
// server errors from the upstream are surfaced as 502 so callers do not
// mistake them for problems with their own credentials.
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
	switch statusCode {
	case http.StatusTooManyRequests:
		return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
	case 529: // non-standard "overloaded" status used by some providers
		return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
	case http.StatusUnauthorized:
		return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
	case http.StatusForbidden:
		return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
	case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
		return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
	default:
		return http.StatusBadGateway, "upstream_error", "Upstream request failed"
	}
}
2025-12-22 22:58:31 +08:00
// handleStreamingAwareError handles errors that may occur after streaming has started
func ( h * OpenAIGatewayHandler ) handleStreamingAwareError ( c * gin . Context , status int , errType , message string , streamStarted bool ) {
if streamStarted {
// Stream already started, send error as SSE event then close
flusher , ok := c . Writer . ( http . Flusher )
if ok {
// Send error event in OpenAI SSE format
errorEvent := fmt . Sprintf ( ` event: error ` + "\n" + ` data: { "error": { "type": "%s", "message": "%s"}} ` + "\n\n" , errType , message )
if _ , err := fmt . Fprint ( c . Writer , errorEvent ) ; err != nil {
_ = c . Error ( err )
}
flusher . Flush ( )
}
return
}
// Normal case: return JSON response with proper status code
h . errorResponse ( c , status , errType , message )
}
// errorResponse writes an error in the OpenAI API response format:
// {"error": {"type": ..., "message": ...}} with the given HTTP status.
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
	payload := gin.H{
		"error": gin.H{
			"type":    errType,
			"message": message,
		},
	}
	c.JSON(status, payload)
}