2025-12-22 22:58:31 +08:00
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"bufio"
|
|
|
|
|
|
"bytes"
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"crypto/sha256"
|
|
|
|
|
|
"encoding/hex"
|
|
|
|
|
|
"encoding/json"
|
|
|
|
|
|
"errors"
|
|
|
|
|
|
"fmt"
|
|
|
|
|
|
"io"
|
2025-12-29 03:17:25 +08:00
|
|
|
|
"log"
|
2025-12-22 22:58:31 +08:00
|
|
|
|
"net/http"
|
2025-12-26 03:49:55 -08:00
|
|
|
|
"regexp"
|
2026-01-01 04:01:51 +08:00
|
|
|
|
"sort"
|
2025-12-23 16:26:07 +08:00
|
|
|
|
"strconv"
|
2025-12-22 22:58:31 +08:00
|
|
|
|
"strings"
|
2026-01-04 20:19:07 +08:00
|
|
|
|
"sync/atomic"
|
2025-12-22 22:58:31 +08:00
|
|
|
|
"time"
|
|
|
|
|
|
|
2025-12-24 21:07:21 +08:00
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
2026-01-10 03:12:56 +08:00
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
2026-01-02 17:40:57 +08:00
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
2025-12-22 22:58:31 +08:00
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
const (
	// chatgptCodexURL is the ChatGPT internal Codex endpoint used for OAuth accounts.
	chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
	// openaiPlatformAPIURL is the OpenAI Platform Responses API used for API Key accounts (fallback).
	openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
	// openaiStickySessionTTL is how long a session -> account binding remains valid (sticky session TTL).
	openaiStickySessionTTL = time.Hour
)
|
|
|
|
|
|
|
2025-12-26 03:49:55 -08:00
|
|
|
|
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "),
// so the pattern accepts zero or more whitespace characters after the colon.
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// OpenAI allowed headers whitelist (for non-OAuth accounts).
// Keys are lowercase header/metadata names; note the non-standard
// underscore-style "session_id" / "conversation_id" entries used by Codex clients.
var openaiAllowedHeaders = map[string]bool{
	"accept-language": true,
	"content-type":    true,
	"conversation_id": true,
	"user-agent":      true,
	"originator":      true,
	"session_id":      true,
}
|
|
|
|
|
|
|
2025-12-23 16:26:07 +08:00
|
|
|
|
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers.
// "Primary" and "Secondary" correspond to the two rate-limit windows reported by
// the upstream. All fields are pointers so values absent from the headers are
// omitted from the JSON encoding.
type OpenAICodexUsageSnapshot struct {
	// Primary window: percent used, seconds until reset, and window length in minutes.
	PrimaryUsedPercent       *float64 `json:"primary_used_percent,omitempty"`
	PrimaryResetAfterSeconds *int     `json:"primary_reset_after_seconds,omitempty"`
	PrimaryWindowMinutes     *int     `json:"primary_window_minutes,omitempty"`
	// Secondary window: same semantics as the primary fields.
	SecondaryUsedPercent       *float64 `json:"secondary_used_percent,omitempty"`
	SecondaryResetAfterSeconds *int     `json:"secondary_reset_after_seconds,omitempty"`
	SecondaryWindowMinutes     *int     `json:"secondary_window_minutes,omitempty"`
	// PrimaryOverSecondaryPercent relates primary usage to the secondary window.
	PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"`
	// UpdatedAt records when this snapshot was captured (string timestamp).
	UpdatedAt string `json:"updated_at,omitempty"`
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// OpenAIUsage represents OpenAI API response usage (token counts).
type OpenAIUsage struct {
	InputTokens  int `json:"input_tokens"`  // prompt tokens consumed
	OutputTokens int `json:"output_tokens"` // completion tokens produced
	// Cache-related token counts; omitted from JSON when zero.
	CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
	CacheReadInputTokens     int `json:"cache_read_input_tokens,omitempty"`
}
|
|
|
|
|
|
|
|
|
|
|
|
// OpenAIForwardResult represents the result of forwarding a request upstream.
type OpenAIForwardResult struct {
	RequestID string        // upstream request identifier
	Usage     OpenAIUsage   // token usage reported by the upstream response
	Model     string        // model used for the request
	Stream    bool          // whether the response was streamed
	Duration  time.Duration // total forwarding duration
	// FirstTokenMs is the time to first token in milliseconds, when measured;
	// nil when unavailable.
	FirstTokenMs *int
}
|
|
|
|
|
|
|
|
|
|
|
|
// OpenAIGatewayService handles OpenAI API gateway operations: account selection
// (sticky sessions, priority/LRU, load-aware scheduling), credential resolution,
// and forwarding requests to the upstream OpenAI/Codex endpoints.
type OpenAIGatewayService struct {
	// Persistence and lookup.
	accountRepo  AccountRepository
	usageLogRepo UsageLogRepository
	userRepo     UserRepository
	userSubRepo  UserSubscriptionRepository
	cache        GatewayCache
	cfg          *config.Config

	// Scheduling and flow control. Both may be nil; methods fall back to the
	// repository (snapshot) or treat acquisition as always-successful (concurrency).
	schedulerSnapshot  *SchedulerSnapshotService
	concurrencyService *ConcurrencyService

	// Billing and rate limiting.
	billingService      *BillingService
	rateLimitService    *RateLimitService
	billingCacheService *BillingCacheService

	// Upstream transport and deferred side effects.
	httpUpstream    HTTPUpstream
	deferredService *DeferredService

	// OAuth token caching; may be nil, in which case GetAccessToken reads the
	// token directly from the account credentials.
	openAITokenProvider *OpenAITokenProvider

	// Codex tool-call correction; constructed in NewOpenAIGatewayService.
	toolCorrector *CodexToolCorrector
}
|
|
|
|
|
|
|
|
|
|
|
|
// NewOpenAIGatewayService creates a new OpenAIGatewayService.
// All collaborators are injected except toolCorrector, which is constructed
// internally. Nil-tolerant collaborators (schedulerSnapshot, concurrencyService,
// openAITokenProvider) may be nil; the methods that use them check for nil and
// fall back accordingly.
func NewOpenAIGatewayService(
	accountRepo AccountRepository,
	usageLogRepo UsageLogRepository,
	userRepo UserRepository,
	userSubRepo UserSubscriptionRepository,
	cache GatewayCache,
	cfg *config.Config,
	schedulerSnapshot *SchedulerSnapshotService,
	concurrencyService *ConcurrencyService,
	billingService *BillingService,
	rateLimitService *RateLimitService,
	billingCacheService *BillingCacheService,
	httpUpstream HTTPUpstream,
	deferredService *DeferredService,
	openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
	return &OpenAIGatewayService{
		accountRepo:         accountRepo,
		usageLogRepo:        usageLogRepo,
		userRepo:            userRepo,
		userSubRepo:         userSubRepo,
		cache:               cache,
		cfg:                 cfg,
		schedulerSnapshot:   schedulerSnapshot,
		concurrencyService:  concurrencyService,
		billingService:      billingService,
		rateLimitService:    rateLimitService,
		billingCacheService: billingCacheService,
		httpUpstream:        httpUpstream,
		deferredService:     deferredService,
		openAITokenProvider: openAITokenProvider,
		toolCorrector:       NewCodexToolCorrector(),
	}
}
|
|
|
|
|
|
|
|
|
|
|
|
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
|
|
|
|
|
|
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
|
|
|
|
|
sessionID := c.GetHeader("session_id")
|
|
|
|
|
|
if sessionID == "" {
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
hash := sha256.Sum256([]byte(sessionID))
|
|
|
|
|
|
return hex.EncodeToString(hash[:])
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-01 04:01:51 +08:00
|
|
|
|
// BindStickySession sets session -> account binding with standard TTL.
|
2026-01-08 23:07:00 +08:00
|
|
|
|
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
2026-01-01 04:01:51 +08:00
|
|
|
|
if sessionHash == "" || accountID <= 0 {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
2026-01-08 23:07:00 +08:00
|
|
|
|
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL)
|
2026-01-01 04:01:51 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// SelectAccount selects an OpenAI account with sticky session support.
// It is a convenience wrapper over SelectAccountForModel with no model filter.
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
	return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
|
|
|
|
|
|
|
|
|
|
|
|
// SelectAccountForModel selects an account supporting the requested model.
// It is a convenience wrapper over SelectAccountForModelWithExclusions with no
// excluded accounts.
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
	return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
|
|
|
|
|
|
|
|
|
|
|
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
//
// Selection order:
//  1. sticky session: reuse the cached session -> account binding when the
//     bound account is still schedulable, an OpenAI account, not excluded, and
//     supports the requested model;
//  2. otherwise scan all schedulable accounts, choosing by priority (lower
//     value wins) with least-recently-used as the tiebreaker (never-used
//     accounts win ties).
//
// A successful non-sticky selection (re)binds the session to the chosen account.
// excludedIDs may be nil; lookups on a nil map are safe.
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
	// 1. Check sticky session
	if sessionHash != "" {
		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
		if err == nil && accountID > 0 {
			if _, excluded := excludedIDs[accountID]; !excluded {
				account, err := s.getSchedulableAccount(ctx, accountID)
				if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
					// Refresh sticky session TTL (best-effort; failure is non-fatal)
					_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
					return account, nil
				}
			}
		}
	}

	// 2. Get schedulable OpenAI accounts
	accounts, err := s.listSchedulableAccounts(ctx, groupID)
	if err != nil {
		return nil, fmt.Errorf("query accounts failed: %w", err)
	}

	// 3. Select by priority + LRU
	var selected *Account
	for i := range accounts {
		acc := &accounts[i]
		if _, excluded := excludedIDs[acc.ID]; excluded {
			continue
		}
		// Scheduler snapshots can be temporarily stale; re-check schedulability here to
		// avoid selecting accounts that were recently rate-limited/overloaded.
		if !acc.IsSchedulable() {
			continue
		}
		// Check model support
		if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
			continue
		}
		if selected == nil {
			selected = acc
			continue
		}
		// Lower priority value means higher priority
		if acc.Priority < selected.Priority {
			selected = acc
		} else if acc.Priority == selected.Priority {
			switch {
			case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
				selected = acc
			case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
				// keep selected (never used is preferred)
			case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
				// keep selected (both never used)
			default:
				// Same priority, select least recently used
				if acc.LastUsedAt.Before(*selected.LastUsedAt) {
					selected = acc
				}
			}
		}
	}

	if selected == nil {
		if requestedModel != "" {
			return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
		}
		return nil, errors.New("no available OpenAI accounts")
	}

	// 4. Set sticky session (best-effort; failure is non-fatal)
	if sessionHash != "" {
		_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
	}

	return selected, nil
}
|
|
|
|
|
|
|
2026-01-01 04:01:51 +08:00
|
|
|
|
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
//
// When the concurrency service is unavailable or load batching is disabled, it
// falls back to plain selection plus a single slot-acquisition attempt,
// returning a WaitPlan (sticky-session or fallback parameters) when the slot
// cannot be acquired immediately.
//
// Otherwise selection proceeds in three layers:
//  1. sticky session — try to reuse and immediately acquire the bound account;
//     if busy but the wait queue is short, return a sticky-session WaitPlan;
//  2. load-aware — filter candidates, fetch batch load info, and acquire the
//     best account ordered by priority, load rate, then LRU (if batch load
//     lookup fails, fall back to priority/LRU ordering);
//  3. fallback wait — return a WaitPlan for the best remaining candidate.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
	cfg := s.schedulingConfig()
	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil {
			stickyAccountID = accountID
		}
	}
	if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
		// Simple path: plain selection + one acquisition attempt.
		account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		if err != nil {
			return nil, err
		}
		result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
		if err == nil && result.Acquired {
			return &AccountSelectionResult{
				Account:     account,
				Acquired:    true,
				ReleaseFunc: result.ReleaseFunc,
			}, nil
		}
		// Busy sticky account with a short queue: wait with sticky-session limits.
		if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
			waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
			if waitingCount < cfg.StickySessionMaxWaiting {
				return &AccountSelectionResult{
					Account: account,
					WaitPlan: &AccountWaitPlan{
						AccountID:      account.ID,
						MaxConcurrency: account.Concurrency,
						Timeout:        cfg.StickySessionWaitTimeout,
						MaxWaiting:     cfg.StickySessionMaxWaiting,
					},
				}, nil
			}
		}
		return &AccountSelectionResult{
			Account: account,
			WaitPlan: &AccountWaitPlan{
				AccountID:      account.ID,
				MaxConcurrency: account.Concurrency,
				Timeout:        cfg.FallbackWaitTimeout,
				MaxWaiting:     cfg.FallbackMaxWaiting,
			},
		}, nil
	}

	accounts, err := s.listSchedulableAccounts(ctx, groupID)
	if err != nil {
		return nil, err
	}
	if len(accounts) == 0 {
		return nil, errors.New("no available accounts")
	}

	isExcluded := func(accountID int64) bool {
		if excludedIDs == nil {
			return false
		}
		_, excluded := excludedIDs[accountID]
		return excluded
	}

	// ============ Layer 1: Sticky session ============
	if sessionHash != "" {
		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
		if err == nil && accountID > 0 && !isExcluded(accountID) {
			account, err := s.getSchedulableAccount(ctx, accountID)
			if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
				(requestedModel == "" || account.IsModelSupported(requestedModel)) {
				result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
				if err == nil && result.Acquired {
					_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
					return &AccountSelectionResult{
						Account:     account,
						Acquired:    true,
						ReleaseFunc: result.ReleaseFunc,
					}, nil
				}

				// Slot busy: wait on the sticky account only if its queue is short.
				waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
				if waitingCount < cfg.StickySessionMaxWaiting {
					return &AccountSelectionResult{
						Account: account,
						WaitPlan: &AccountWaitPlan{
							AccountID:      accountID,
							MaxConcurrency: account.Concurrency,
							Timeout:        cfg.StickySessionWaitTimeout,
							MaxWaiting:     cfg.StickySessionMaxWaiting,
						},
					}, nil
				}
			}
		}
	}

	// ============ Layer 2: Load-aware selection ============
	candidates := make([]*Account, 0, len(accounts))
	for i := range accounts {
		acc := &accounts[i]
		if isExcluded(acc.ID) {
			continue
		}
		// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
		// re-check schedulability here so recently rate-limited/overloaded accounts
		// are not selected again before the bucket is rebuilt.
		if !acc.IsSchedulable() {
			continue
		}
		if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
			continue
		}
		candidates = append(candidates, acc)
	}

	if len(candidates) == 0 {
		return nil, errors.New("no available accounts")
	}

	accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
	for _, acc := range candidates {
		accountLoads = append(accountLoads, AccountWithConcurrency{
			ID:             acc.ID,
			MaxConcurrency: acc.Concurrency,
		})
	}

	loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
	if err != nil {
		// Load lookup failed: fall back to priority/LRU order and take the first
		// account whose slot can be acquired.
		ordered := append([]*Account(nil), candidates...)
		sortAccountsByPriorityAndLastUsed(ordered, false)
		for _, acc := range ordered {
			result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
			if err == nil && result.Acquired {
				if sessionHash != "" {
					_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
				}
				return &AccountSelectionResult{
					Account:     acc,
					Acquired:    true,
					ReleaseFunc: result.ReleaseFunc,
				}, nil
			}
		}
	} else {
		type accountWithLoad struct {
			account  *Account
			loadInfo *AccountLoadInfo
		}
		// Keep only accounts below full load (LoadRate < 100).
		var available []accountWithLoad
		for _, acc := range candidates {
			loadInfo := loadMap[acc.ID]
			if loadInfo == nil {
				loadInfo = &AccountLoadInfo{AccountID: acc.ID}
			}
			if loadInfo.LoadRate < 100 {
				available = append(available, accountWithLoad{
					account:  acc,
					loadInfo: loadInfo,
				})
			}
		}

		if len(available) > 0 {
			// Order by priority, then load rate, then least-recently-used
			// (never-used accounts sort first among equals).
			sort.SliceStable(available, func(i, j int) bool {
				a, b := available[i], available[j]
				if a.account.Priority != b.account.Priority {
					return a.account.Priority < b.account.Priority
				}
				if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
					return a.loadInfo.LoadRate < b.loadInfo.LoadRate
				}
				switch {
				case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
					return true
				case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
					return false
				case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
					return false
				default:
					return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
				}
			})

			for _, item := range available {
				result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
				if err == nil && result.Acquired {
					if sessionHash != "" {
						_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
					}
					return &AccountSelectionResult{
						Account:     item.account,
						Acquired:    true,
						ReleaseFunc: result.ReleaseFunc,
					}, nil
				}
			}
		}
	}

	// ============ Layer 3: Fallback wait ============
	sortAccountsByPriorityAndLastUsed(candidates, false)
	// The loop returns on the first (best-ranked) candidate.
	for _, acc := range candidates {
		return &AccountSelectionResult{
			Account: acc,
			WaitPlan: &AccountWaitPlan{
				AccountID:      acc.ID,
				MaxConcurrency: acc.Concurrency,
				Timeout:        cfg.FallbackWaitTimeout,
				MaxWaiting:     cfg.FallbackMaxWaiting,
			},
		}, nil
	}

	return nil, errors.New("no available accounts")
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
2026-01-12 14:19:06 +08:00
|
|
|
|
if s.schedulerSnapshot != nil {
|
|
|
|
|
|
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
|
|
|
|
|
|
return accounts, err
|
|
|
|
|
|
}
|
2026-01-01 04:01:51 +08:00
|
|
|
|
var accounts []Account
|
|
|
|
|
|
var err error
|
|
|
|
|
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
|
|
|
|
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
|
|
|
|
|
} else if groupID != nil {
|
|
|
|
|
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
|
|
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
return accounts, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
|
|
|
|
|
if s.concurrencyService == nil {
|
|
|
|
|
|
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-12 14:19:06 +08:00
|
|
|
|
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
|
|
|
|
|
if s.schedulerSnapshot != nil {
|
|
|
|
|
|
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
|
|
|
|
|
}
|
|
|
|
|
|
return s.accountRepo.GetByID(ctx, accountID)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-01 04:01:51 +08:00
|
|
|
|
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
|
|
|
|
|
if s.cfg != nil {
|
|
|
|
|
|
return s.cfg.Gateway.Scheduling
|
|
|
|
|
|
}
|
|
|
|
|
|
return config.GatewaySchedulingConfig{
|
|
|
|
|
|
StickySessionMaxWaiting: 3,
|
|
|
|
|
|
StickySessionWaitTimeout: 45 * time.Second,
|
|
|
|
|
|
FallbackWaitTimeout: 30 * time.Second,
|
|
|
|
|
|
FallbackMaxWaiting: 100,
|
|
|
|
|
|
LoadBatchEnabled: true,
|
|
|
|
|
|
SlotCleanupInterval: 30 * time.Second,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// GetAccessToken gets the access token for an OpenAI account
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
2025-12-23 10:25:32 +08:00
|
|
|
|
switch account.Type {
|
2025-12-26 15:40:24 +08:00
|
|
|
|
case AccountTypeOAuth:
|
2026-01-15 18:27:06 +08:00
|
|
|
|
// 使用 TokenProvider 获取缓存的 token
|
|
|
|
|
|
if s.openAITokenProvider != nil {
|
|
|
|
|
|
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return "", "", err
|
|
|
|
|
|
}
|
|
|
|
|
|
return accessToken, "oauth", nil
|
|
|
|
|
|
}
|
|
|
|
|
|
// 降级:TokenProvider 未配置时直接从账号读取
|
2025-12-22 22:58:31 +08:00
|
|
|
|
accessToken := account.GetOpenAIAccessToken()
|
|
|
|
|
|
if accessToken == "" {
|
|
|
|
|
|
return "", "", errors.New("access_token not found in credentials")
|
|
|
|
|
|
}
|
|
|
|
|
|
return accessToken, "oauth", nil
|
2026-01-04 19:27:53 +08:00
|
|
|
|
case AccountTypeAPIKey:
|
2025-12-22 22:58:31 +08:00
|
|
|
|
apiKey := account.GetOpenAIApiKey()
|
|
|
|
|
|
if apiKey == "" {
|
|
|
|
|
|
return "", "", errors.New("api_key not found in credentials")
|
|
|
|
|
|
}
|
|
|
|
|
|
return apiKey, "apikey", nil
|
2025-12-23 10:25:32 +08:00
|
|
|
|
default:
|
|
|
|
|
|
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
2025-12-22 22:58:31 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-27 11:44:00 +08:00
|
|
|
|
func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
|
|
|
|
|
switch statusCode {
|
2025-12-31 11:46:53 +08:00
|
|
|
|
case 401, 402, 403, 429, 529:
|
2025-12-27 11:44:00 +08:00
|
|
|
|
return true
|
|
|
|
|
|
default:
|
|
|
|
|
|
return statusCode >= 500
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
2026-01-11 15:30:27 +08:00
|
|
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
2025-12-27 11:44:00 +08:00
|
|
|
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// Forward forwards request to OpenAI API
|
2025-12-26 15:40:24 +08:00
|
|
|
|
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
2025-12-22 22:58:31 +08:00
|
|
|
|
startTime := time.Now()
|
|
|
|
|
|
|
|
|
|
|
|
// Parse request body once (avoid multiple parse/serialize cycles)
|
|
|
|
|
|
var reqBody map[string]any
|
|
|
|
|
|
if err := json.Unmarshal(body, &reqBody); err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("parse request: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Extract model and stream from parsed body
|
|
|
|
|
|
reqModel, _ := reqBody["model"].(string)
|
|
|
|
|
|
reqStream, _ := reqBody["stream"].(bool)
|
2026-01-09 18:35:58 +08:00
|
|
|
|
promptCacheKey := ""
|
2026-01-10 03:12:56 +08:00
|
|
|
|
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
|
|
|
|
|
promptCacheKey = strings.TrimSpace(v)
|
|
|
|
|
|
}
|
2025-12-22 22:58:31 +08:00
|
|
|
|
|
|
|
|
|
|
// Track if body needs re-serialization
|
|
|
|
|
|
bodyModified := false
|
|
|
|
|
|
originalModel := reqModel
|
|
|
|
|
|
|
2026-01-10 03:12:56 +08:00
|
|
|
|
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
|
|
|
|
|
|
2026-01-13 17:01:21 +08:00
|
|
|
|
// 对所有请求执行模型映射(包含 Codex CLI)。
|
2026-01-12 13:23:05 -08:00
|
|
|
|
mappedModel := account.GetMappedModel(reqModel)
|
|
|
|
|
|
if mappedModel != reqModel {
|
|
|
|
|
|
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
|
|
|
|
|
reqBody["model"] = mappedModel
|
|
|
|
|
|
bodyModified = true
|
2025-12-22 22:58:31 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-13 17:01:21 +08:00
|
|
|
|
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
2026-01-12 20:18:53 -08:00
|
|
|
|
if model, ok := reqBody["model"].(string); ok {
|
|
|
|
|
|
normalizedModel := normalizeCodexModel(model)
|
|
|
|
|
|
if normalizedModel != "" && normalizedModel != model {
|
|
|
|
|
|
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
|
|
|
|
|
model, normalizedModel, account.Name, account.Type, isCodexCLI)
|
|
|
|
|
|
reqBody["model"] = normalizedModel
|
|
|
|
|
|
mappedModel = normalizedModel
|
|
|
|
|
|
bodyModified = true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-13 17:01:21 +08:00
|
|
|
|
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
2026-01-12 20:18:53 -08:00
|
|
|
|
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
|
|
|
|
|
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
|
|
|
|
|
|
reasoning["effort"] = "none"
|
|
|
|
|
|
bodyModified = true
|
|
|
|
|
|
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-10 03:12:56 +08:00
|
|
|
|
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
|
|
|
|
|
codexResult := applyCodexOAuthTransform(reqBody)
|
2026-01-09 18:35:58 +08:00
|
|
|
|
if codexResult.Modified {
|
2026-01-09 00:34:49 +08:00
|
|
|
|
bodyModified = true
|
|
|
|
|
|
}
|
2026-01-09 18:35:58 +08:00
|
|
|
|
if codexResult.NormalizedModel != "" {
|
|
|
|
|
|
mappedModel = codexResult.NormalizedModel
|
|
|
|
|
|
}
|
|
|
|
|
|
if codexResult.PromptCacheKey != "" {
|
|
|
|
|
|
promptCacheKey = codexResult.PromptCacheKey
|
|
|
|
|
|
}
|
2025-12-22 22:58:31 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-12 11:08:28 -08:00
|
|
|
|
// Handle max_output_tokens based on platform and account type
|
|
|
|
|
|
if !isCodexCLI {
|
|
|
|
|
|
if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens {
|
|
|
|
|
|
switch account.Platform {
|
|
|
|
|
|
case PlatformOpenAI:
|
|
|
|
|
|
// For OpenAI API Key, remove max_output_tokens (not supported)
|
|
|
|
|
|
// For OpenAI OAuth (Responses API), keep it (supported)
|
|
|
|
|
|
if account.Type == AccountTypeAPIKey {
|
|
|
|
|
|
delete(reqBody, "max_output_tokens")
|
|
|
|
|
|
bodyModified = true
|
|
|
|
|
|
}
|
|
|
|
|
|
case PlatformAnthropic:
|
|
|
|
|
|
// For Anthropic (Claude), convert to max_tokens
|
|
|
|
|
|
delete(reqBody, "max_output_tokens")
|
|
|
|
|
|
if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens {
|
|
|
|
|
|
reqBody["max_tokens"] = maxOutputTokens
|
|
|
|
|
|
}
|
|
|
|
|
|
bodyModified = true
|
|
|
|
|
|
case PlatformGemini:
|
|
|
|
|
|
// For Gemini, remove (will be handled by Gemini-specific transform)
|
|
|
|
|
|
delete(reqBody, "max_output_tokens")
|
|
|
|
|
|
bodyModified = true
|
|
|
|
|
|
default:
|
|
|
|
|
|
// For unknown platforms, remove to be safe
|
|
|
|
|
|
delete(reqBody, "max_output_tokens")
|
|
|
|
|
|
bodyModified = true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Also handle max_completion_tokens (similar logic)
|
|
|
|
|
|
if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens {
|
|
|
|
|
|
if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI {
|
|
|
|
|
|
delete(reqBody, "max_completion_tokens")
|
|
|
|
|
|
bodyModified = true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// Re-serialize body only if modified
|
|
|
|
|
|
if bodyModified {
|
|
|
|
|
|
var err error
|
|
|
|
|
|
body, err = json.Marshal(reqBody)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("serialize request body: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Get access token
|
|
|
|
|
|
token, _, err := s.GetAccessToken(ctx, account)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Build upstream request
|
2026-01-10 20:53:16 +08:00
|
|
|
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
2025-12-22 22:58:31 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Get proxy URL
|
|
|
|
|
|
proxyURL := ""
|
|
|
|
|
|
if account.ProxyID != nil && account.Proxy != nil {
|
|
|
|
|
|
proxyURL = account.Proxy.URL()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-15 15:14:44 +08:00
|
|
|
|
// Capture upstream request body for ops retry of this attempt.
|
|
|
|
|
|
if c != nil {
|
|
|
|
|
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// Send request
|
2025-12-31 11:43:58 +08:00
|
|
|
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
2025-12-22 22:58:31 +08:00
|
|
|
|
if err != nil {
|
2026-01-11 11:49:34 +08:00
|
|
|
|
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
|
|
|
|
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
|
|
|
|
|
setOpsUpstreamError(c, 0, safeErr, "")
|
2026-01-11 15:30:27 +08:00
|
|
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
|
|
|
|
Platform: account.Platform,
|
|
|
|
|
|
AccountID: account.ID,
|
2026-01-15 15:14:44 +08:00
|
|
|
|
AccountName: account.Name,
|
2026-01-11 15:30:27 +08:00
|
|
|
|
UpstreamStatusCode: 0,
|
|
|
|
|
|
Kind: "request_error",
|
|
|
|
|
|
Message: safeErr,
|
|
|
|
|
|
})
|
2026-01-11 11:49:34 +08:00
|
|
|
|
c.JSON(http.StatusBadGateway, gin.H{
|
|
|
|
|
|
"error": gin.H{
|
|
|
|
|
|
"type": "upstream_error",
|
|
|
|
|
|
"message": "Upstream request failed",
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
2025-12-22 22:58:31 +08:00
|
|
|
|
}
|
|
|
|
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
|
|
|
|
|
|
|
|
// Handle error response
|
|
|
|
|
|
if resp.StatusCode >= 400 {
|
2025-12-27 11:44:00 +08:00
|
|
|
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
2026-01-11 15:30:27 +08:00
|
|
|
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
|
|
|
|
_ = resp.Body.Close()
|
|
|
|
|
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
|
|
|
|
|
|
|
|
|
|
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
|
|
|
|
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|
|
|
|
|
upstreamDetail := ""
|
|
|
|
|
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|
|
|
|
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|
|
|
|
|
if maxBytes <= 0 {
|
|
|
|
|
|
maxBytes = 2048
|
|
|
|
|
|
}
|
|
|
|
|
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
|
|
|
|
|
}
|
|
|
|
|
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
|
|
|
|
Platform: account.Platform,
|
|
|
|
|
|
AccountID: account.ID,
|
2026-01-15 15:14:44 +08:00
|
|
|
|
AccountName: account.Name,
|
2026-01-11 15:30:27 +08:00
|
|
|
|
UpstreamStatusCode: resp.StatusCode,
|
|
|
|
|
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|
|
|
|
|
Kind: "failover",
|
|
|
|
|
|
Message: upstreamMsg,
|
|
|
|
|
|
Detail: upstreamDetail,
|
|
|
|
|
|
})
|
|
|
|
|
|
|
2025-12-27 11:44:00 +08:00
|
|
|
|
s.handleFailoverSideEffects(ctx, resp, account)
|
|
|
|
|
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
|
|
|
|
|
}
|
2025-12-22 22:58:31 +08:00
|
|
|
|
return s.handleErrorResponse(ctx, resp, c, account)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Handle normal response
|
|
|
|
|
|
var usage *OpenAIUsage
|
|
|
|
|
|
var firstTokenMs *int
|
|
|
|
|
|
if reqStream {
|
|
|
|
|
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
usage = streamResult.usage
|
|
|
|
|
|
firstTokenMs = streamResult.firstTokenMs
|
|
|
|
|
|
} else {
|
|
|
|
|
|
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-23 16:26:07 +08:00
|
|
|
|
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
2025-12-26 15:40:24 +08:00
|
|
|
|
if account.Type == AccountTypeOAuth {
|
2025-12-23 16:26:07 +08:00
|
|
|
|
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
|
|
|
|
|
|
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
return &OpenAIForwardResult{
|
|
|
|
|
|
RequestID: resp.Header.Get("x-request-id"),
|
|
|
|
|
|
Usage: *usage,
|
|
|
|
|
|
Model: originalModel,
|
|
|
|
|
|
Stream: reqStream,
|
|
|
|
|
|
Duration: time.Since(startTime),
|
|
|
|
|
|
FirstTokenMs: firstTokenMs,
|
|
|
|
|
|
}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-10 20:53:16 +08:00
|
|
|
|
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// Determine target URL based on account type
|
|
|
|
|
|
var targetURL string
|
2025-12-23 10:25:32 +08:00
|
|
|
|
switch account.Type {
|
2025-12-26 15:40:24 +08:00
|
|
|
|
case AccountTypeOAuth:
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// OAuth accounts use ChatGPT internal API
|
|
|
|
|
|
targetURL = chatgptCodexURL
|
2026-01-04 19:27:53 +08:00
|
|
|
|
case AccountTypeAPIKey:
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// API Key accounts use Platform API or custom base URL
|
|
|
|
|
|
baseURL := account.GetOpenAIBaseURL()
|
2026-01-02 17:40:57 +08:00
|
|
|
|
if baseURL == "" {
|
2025-12-22 22:58:31 +08:00
|
|
|
|
targetURL = openaiPlatformAPIURL
|
2026-01-02 17:40:57 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
targetURL = validatedURL + "/responses"
|
2025-12-22 22:58:31 +08:00
|
|
|
|
}
|
2025-12-23 10:25:32 +08:00
|
|
|
|
default:
|
2025-12-22 22:58:31 +08:00
|
|
|
|
targetURL = openaiPlatformAPIURL
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Set authentication header
|
|
|
|
|
|
req.Header.Set("authorization", "Bearer "+token)
|
|
|
|
|
|
|
|
|
|
|
|
// Set headers specific to OAuth accounts (ChatGPT internal API)
|
2025-12-26 15:40:24 +08:00
|
|
|
|
if account.Type == AccountTypeOAuth {
|
2025-12-22 22:58:31 +08:00
|
|
|
|
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
|
|
|
|
|
|
req.Host = "chatgpt.com"
|
|
|
|
|
|
// Required: set chatgpt-account-id header
|
|
|
|
|
|
chatgptAccountID := account.GetChatGPTAccountID()
|
|
|
|
|
|
if chatgptAccountID != "" {
|
|
|
|
|
|
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Whitelist passthrough headers
|
|
|
|
|
|
for key, values := range c.Request.Header {
|
|
|
|
|
|
lowerKey := strings.ToLower(key)
|
|
|
|
|
|
if openaiAllowedHeaders[lowerKey] {
|
|
|
|
|
|
for _, v := range values {
|
|
|
|
|
|
req.Header.Add(key, v)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-01-09 18:35:58 +08:00
|
|
|
|
if account.Type == AccountTypeOAuth {
|
|
|
|
|
|
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
2026-01-10 20:53:16 +08:00
|
|
|
|
if isCodexCLI {
|
|
|
|
|
|
req.Header.Set("originator", "codex_cli_rs")
|
|
|
|
|
|
} else {
|
|
|
|
|
|
req.Header.Set("originator", "opencode")
|
|
|
|
|
|
}
|
2026-01-09 18:35:58 +08:00
|
|
|
|
req.Header.Set("accept", "text/event-stream")
|
|
|
|
|
|
if promptCacheKey != "" {
|
|
|
|
|
|
req.Header.Set("conversation_id", promptCacheKey)
|
|
|
|
|
|
req.Header.Set("session_id", promptCacheKey)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-12-22 22:58:31 +08:00
|
|
|
|
|
|
|
|
|
|
// Apply custom User-Agent if configured
|
|
|
|
|
|
customUA := account.GetOpenAIUserAgent()
|
|
|
|
|
|
if customUA != "" {
|
|
|
|
|
|
req.Header.Set("user-agent", customUA)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Ensure required headers exist
|
|
|
|
|
|
if req.Header.Get("content-type") == "" {
|
|
|
|
|
|
req.Header.Set("content-type", "application/json")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return req, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
// handleErrorResponse processes an upstream 4xx/5xx response that did not
// qualify for automatic failover. It records ops telemetry, optionally logs
// the (truncated) upstream body, may mark the account via the rate-limit
// service, and always writes an error response to the client before
// returning a non-nil error to the caller.
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
	// Cap the error body read at 2 MiB to bound memory usage.
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))

	// Extract a sanitized, human-readable message from the upstream body.
	upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
	upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
	upstreamDetail := ""
	if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if maxBytes <= 0 {
			// Default truncation limit when the config value is unset/invalid.
			maxBytes = 2048
		}
		upstreamDetail = truncateString(string(body), maxBytes)
	}
	// Record the upstream error on the gin context for ops reporting.
	setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)

	if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		log.Printf(
			"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
			resp.StatusCode,
			account.ID,
			account.Platform,
			account.Type,
			truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
		)
	}

	// Check custom error codes: when this status code is not configured for
	// handling on the account, report it and return a generic 500 to the client.
	if !account.ShouldHandleErrorCode(resp.StatusCode) {
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: resp.StatusCode,
			UpstreamRequestID:  resp.Header.Get("x-request-id"),
			Kind:               "http_error",
			Message:            upstreamMsg,
			Detail:             upstreamDetail,
		})
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream gateway error",
			},
		})
		if upstreamMsg == "" {
			return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
		}
		return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
	}

	// Handle upstream error (may mark account status via the rate-limit service).
	shouldDisable := false
	if s.rateLimitService != nil {
		shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
	}
	// Classify the ops event: "failover" when the account was disabled.
	kind := "http_error"
	if shouldDisable {
		kind = "failover"
	}
	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
		Platform:           account.Platform,
		AccountID:          account.ID,
		AccountName:        account.Name,
		UpstreamStatusCode: resp.StatusCode,
		UpstreamRequestID:  resp.Header.Get("x-request-id"),
		Kind:               kind,
		Message:            upstreamMsg,
		Detail:             upstreamDetail,
	})
	if shouldDisable {
		// Signal the caller to retry on another account; no client response
		// is written here in that case.
		return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
	}

	// Return an appropriate error response to the client, mapping upstream
	// auth/billing/permission failures to 502 and rate limits to 429.
	var errType, errMsg string
	var statusCode int

	switch resp.StatusCode {
	case 401:
		statusCode = http.StatusBadGateway
		errType = "upstream_error"
		errMsg = "Upstream authentication failed, please contact administrator"
	case 402:
		statusCode = http.StatusBadGateway
		errType = "upstream_error"
		errMsg = "Upstream payment required: insufficient balance or billing issue"
	case 403:
		statusCode = http.StatusBadGateway
		errType = "upstream_error"
		errMsg = "Upstream access forbidden, please contact administrator"
	case 429:
		statusCode = http.StatusTooManyRequests
		errType = "rate_limit_error"
		errMsg = "Upstream rate limit exceeded, please retry later"
	default:
		statusCode = http.StatusBadGateway
		errType = "upstream_error"
		errMsg = "Upstream request failed"
	}

	c.JSON(statusCode, gin.H{
		"error": gin.H{
			"type":    errType,
			"message": errMsg,
		},
	})

	if upstreamMsg == "" {
		return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
	}
	return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
|
|
|
|
|
|
|
|
|
|
|
|
// openaiStreamingResult is the outcome of relaying one streaming response.
type openaiStreamingResult struct {
	usage        *OpenAIUsage // token usage parsed from SSE events (may be zero-valued if none seen)
	firstTokenMs *int         // milliseconds to the first data event, nil if no data arrived
}
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
// handleStreamingResponse relays an upstream SSE stream to the client while
// collecting usage and first-token latency. A dedicated reader goroutine
// feeds lines over a channel so that slow downstream writes cannot block the
// detection of upstream data-interval timeouts; optional tickers provide the
// timeout check and downstream keepalive comments.
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
	if s.cfg != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
	}

	// Set SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")

	// Pass through other headers.
	if v := resp.Header.Get("x-request-id"); v != "" {
		c.Header("x-request-id", v)
	}

	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("streaming not supported")
	}

	usage := &OpenAIUsage{}
	var firstTokenMs *int
	scanner := bufio.NewScanner(resp.Body)
	// Line buffer limit is configurable; defaultMaxLineSize otherwise.
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 64*1024), maxLineSize)

	type scanEvent struct {
		line string
		err  error
	}
	// Read upstream in a separate goroutine so a blocked read cannot stall
	// keepalive/timeout handling in the select loop below.
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	// sendEvent delivers to the consumer unless it has already exited (done
	// closed), which prevents the reader goroutine from leaking.
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	// lastReadAt tracks (atomically) when the upstream last produced a line;
	// it drives the data-interval timeout check.
	var lastReadAt int64
	atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
	go func() {
		defer close(events)
		for scanner.Scan() {
			atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)

	streamInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	// Only monitor the upstream data-interval timeout; downstream write
	// blocking must not affect it.
	var intervalTicker *time.Ticker
	if streamInterval > 0 {
		intervalTicker = time.NewTicker(streamInterval)
		defer intervalTicker.Stop()
	}
	var intervalCh <-chan time.Time
	if intervalTicker != nil {
		intervalCh = intervalTicker.C
	}

	keepaliveInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
	}
	// Downstream keepalive exists only to stop idle proxies from closing
	// the connection.
	var keepaliveTicker *time.Ticker
	if keepaliveInterval > 0 {
		keepaliveTicker = time.NewTicker(keepaliveInterval)
		defer keepaliveTicker.Stop()
	}
	var keepaliveCh <-chan time.Time
	if keepaliveTicker != nil {
		keepaliveCh = keepaliveTicker.C
	}
	// Time of the last upstream data seen, used to rate-limit keepalives.
	lastDataAt := time.Now()

	// Emit at most one error event to avoid corrupting the SSE protocol with
	// multiple writes (best-effort notification if the write itself fails).
	errorEventSent := false
	sendErrorEvent := func(reason string) {
		if errorEventSent {
			return
		}
		errorEventSent = true
		_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
		flusher.Flush()
	}

	// Model replacement is only needed when the client-facing model name was
	// mapped to a different upstream model.
	needModelReplace := originalModel != mappedModel

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream closed cleanly; return what we collected.
				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
			}
			if ev.err != nil {
				if errors.Is(ev.err, bufio.ErrTooLong) {
					log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
					sendErrorEvent("response_too_large")
					return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
				}
				sendErrorEvent("stream_read_error")
				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
			}

			line := ev.line
			lastDataAt = time.Now()

			// Extract data from SSE line (supports both "data: " and "data:" formats).
			if openaiSSEDataRe.MatchString(line) {
				data := openaiSSEDataRe.ReplaceAllString(line, "")

				// Replace model in response if needed.
				if needModelReplace {
					line = s.replaceModelInSSELine(line, mappedModel, originalModel)
				}

				// Correct Codex tool calls if needed (apply_patch -> edit, etc.).
				// NOTE(review): the corrector runs on `data`, captured before the
				// model replacement above — when both apply, the corrected line
				// appears to lose the model rename. Confirm intended precedence.
				if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
					line = "data: " + correctedData
				}

				// Forward line.
				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
					sendErrorEvent("write_failed")
					return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
				}
				flusher.Flush()

				// Record first token time (first non-empty, non-terminal data event).
				if firstTokenMs == nil && data != "" && data != "[DONE]" {
					ms := int(time.Since(startTime).Milliseconds())
					firstTokenMs = &ms
				}
				s.parseSSEUsage(data, usage)
			} else {
				// Forward non-data lines as-is.
				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
					sendErrorEvent("write_failed")
					return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
				}
				flusher.Flush()
			}

		case <-intervalCh:
			// Skip the timeout if data arrived within the interval window.
			lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
			if time.Since(lastRead) < streamInterval {
				continue
			}
			log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
			// Handle the stream timeout; may mark the account temporarily
			// unschedulable or in an error state.
			if s.rateLimitService != nil {
				s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
			}
			sendErrorEvent("stream_timeout")
			return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")

		case <-keepaliveCh:
			// Only send a keepalive comment when the stream has been quiet.
			if time.Since(lastDataAt) < keepaliveInterval {
				continue
			}
			if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
			}
			flusher.Flush()
		}
	}
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
2025-12-26 03:49:55 -08:00
|
|
|
|
if !openaiSSEDataRe.MatchString(line) {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
2025-12-22 22:58:31 +08:00
|
|
|
|
if data == "" || data == "[DONE]" {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var event map[string]any
|
|
|
|
|
|
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Replace model in response
|
|
|
|
|
|
if m, ok := event["model"].(string); ok && m == fromModel {
|
|
|
|
|
|
event["model"] = toModel
|
|
|
|
|
|
newData, err := json.Marshal(event)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
return "data: " + string(newData)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Check nested response
|
|
|
|
|
|
if response, ok := event["response"].(map[string]any); ok {
|
|
|
|
|
|
if m, ok := response["model"].(string); ok && m == fromModel {
|
|
|
|
|
|
response["model"] = toModel
|
|
|
|
|
|
newData, err := json.Marshal(event)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
return "data: " + string(newData)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return line
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-15 23:52:50 +08:00
|
|
|
|
// correctToolCallsInResponseBody 修正响应体中的工具调用
|
|
|
|
|
|
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
|
|
|
|
|
|
if len(body) == 0 {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bodyStr := string(body)
|
|
|
|
|
|
corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
|
|
|
|
|
|
if changed {
|
|
|
|
|
|
return []byte(corrected)
|
|
|
|
|
|
}
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
|
|
|
|
|
// Parse response.completed event for usage (OpenAI Responses format)
|
|
|
|
|
|
var event struct {
|
|
|
|
|
|
Type string `json:"type"`
|
|
|
|
|
|
Response struct {
|
|
|
|
|
|
Usage struct {
|
|
|
|
|
|
InputTokens int `json:"input_tokens"`
|
|
|
|
|
|
OutputTokens int `json:"output_tokens"`
|
|
|
|
|
|
InputTokenDetails struct {
|
|
|
|
|
|
CachedTokens int `json:"cached_tokens"`
|
|
|
|
|
|
} `json:"input_tokens_details"`
|
|
|
|
|
|
} `json:"usage"`
|
|
|
|
|
|
} `json:"response"`
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
|
|
|
|
|
|
usage.InputTokens = event.Response.Usage.InputTokens
|
|
|
|
|
|
usage.OutputTokens = event.Response.Usage.OutputTokens
|
|
|
|
|
|
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-26 15:40:24 +08:00
|
|
|
|
// handleNonStreamingResponse relays a non-streaming upstream response to the
// client, returning the parsed token usage. For OAuth accounts an upstream
// body that looks like SSE is diverted to handleOAuthSSEToJSON (the ChatGPT
// Codex API may answer with an event stream even for non-stream requests).
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	if account.Type == AccountTypeOAuth {
		// Heuristic: either the Content-Type says event-stream or the body
		// contains SSE field markers.
		bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
		if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
			return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
		}
	}

	// Parse usage from the JSON response (OpenAI Responses format).
	var response struct {
		Usage struct {
			InputTokens       int `json:"input_tokens"`
			OutputTokens      int `json:"output_tokens"`
			InputTokenDetails struct {
				CachedTokens int `json:"cached_tokens"`
			} `json:"input_tokens_details"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &response); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}

	usage := &OpenAIUsage{
		InputTokens:          response.Usage.InputTokens,
		OutputTokens:         response.Usage.OutputTokens,
		CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
	}

	// Replace model in response if needed (undo the model mapping so the
	// client sees the model name it asked for).
	if originalModel != mappedModel {
		body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
	}

	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)

	// When header filtering is disabled, prefer the upstream Content-Type;
	// otherwise default to application/json.
	contentType := "application/json"
	if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
		if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
			contentType = upstreamType
		}
	}

	c.Data(resp.StatusCode, contentType, body)

	return usage, nil
}
|
|
|
|
|
|
|
2026-01-09 18:35:58 +08:00
|
|
|
|
func isEventStreamResponse(header http.Header) bool {
|
|
|
|
|
|
contentType := strings.ToLower(header.Get("Content-Type"))
|
|
|
|
|
|
return strings.Contains(contentType, "text/event-stream")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// handleOAuthSSEToJSON converts an SSE-formatted upstream body (from the
// ChatGPT Codex API on an OAuth account) into the response sent to the
// client. If a final "response.done"/"response.completed" event is found,
// its embedded response object is returned as plain JSON (with usage parsed
// from it); otherwise the raw SSE body is forwarded and usage is scraped
// from individual data events.
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
	bodyText := string(body)
	finalResponse, ok := extractCodexFinalResponse(bodyText)

	usage := &OpenAIUsage{}
	if ok {
		// Final response found: parse usage from it, then rewrite model and
		// tool calls before returning it as the JSON body.
		var response struct {
			Usage struct {
				InputTokens       int `json:"input_tokens"`
				OutputTokens      int `json:"output_tokens"`
				InputTokenDetails struct {
					CachedTokens int `json:"cached_tokens"`
				} `json:"input_tokens_details"`
			} `json:"usage"`
		}
		if err := json.Unmarshal(finalResponse, &response); err == nil {
			usage.InputTokens = response.Usage.InputTokens
			usage.OutputTokens = response.Usage.OutputTokens
			usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
		}
		body = finalResponse
		if originalModel != mappedModel {
			body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
		}
		// Correct tool calls in the final response (apply_patch -> edit, etc.).
		body = s.correctToolCallsInResponseBody(body)
	} else {
		// No final event: scrape usage from the SSE stream and forward it
		// (with model names rewritten line-by-line if mapping was applied).
		usage = s.parseSSEUsageFromBody(bodyText)
		if originalModel != mappedModel {
			bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
		}
		body = []byte(bodyText)
	}

	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)

	// JSON when a final response was extracted; otherwise preserve the
	// upstream Content-Type (falling back to text/event-stream).
	contentType := "application/json; charset=utf-8"
	if !ok {
		contentType = resp.Header.Get("Content-Type")
		if contentType == "" {
			contentType = "text/event-stream"
		}
	}
	c.Data(resp.StatusCode, contentType, body)

	return usage, nil
}
|
|
|
|
|
|
|
|
|
|
|
|
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
|
|
|
|
|
lines := strings.Split(body, "\n")
|
|
|
|
|
|
for _, line := range lines {
|
|
|
|
|
|
if !openaiSSEDataRe.MatchString(line) {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
|
|
|
|
|
if data == "" || data == "[DONE]" {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
var event struct {
|
|
|
|
|
|
Type string `json:"type"`
|
|
|
|
|
|
Response json.RawMessage `json:"response"`
|
|
|
|
|
|
}
|
|
|
|
|
|
if json.Unmarshal([]byte(data), &event) != nil {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if event.Type == "response.done" || event.Type == "response.completed" {
|
|
|
|
|
|
if len(event.Response) > 0 {
|
|
|
|
|
|
return event.Response, true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil, false
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
|
|
|
|
|
usage := &OpenAIUsage{}
|
|
|
|
|
|
lines := strings.Split(body, "\n")
|
|
|
|
|
|
for _, line := range lines {
|
|
|
|
|
|
if !openaiSSEDataRe.MatchString(line) {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
|
|
|
|
|
if data == "" || data == "[DONE]" {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
s.parseSSEUsage(data, usage)
|
|
|
|
|
|
}
|
|
|
|
|
|
return usage
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
|
|
|
|
|
lines := strings.Split(body, "\n")
|
|
|
|
|
|
for i, line := range lines {
|
|
|
|
|
|
if !openaiSSEDataRe.MatchString(line) {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
|
|
|
|
|
}
|
|
|
|
|
|
return strings.Join(lines, "\n")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-02 17:40:57 +08:00
|
|
|
|
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
2026-01-05 13:54:43 +08:00
|
|
|
|
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
2026-01-05 14:41:08 +08:00
|
|
|
|
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return "", fmt.Errorf("invalid base_url: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
return normalized, nil
|
2026-01-05 13:54:43 +08:00
|
|
|
|
}
|
2026-01-02 17:40:57 +08:00
|
|
|
|
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
|
|
|
|
|
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
|
|
|
|
|
RequireAllowlist: true,
|
|
|
|
|
|
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
|
|
|
|
|
})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return "", fmt.Errorf("invalid base_url: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
return normalized, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-22 22:58:31 +08:00
|
|
|
|
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
|
|
|
|
|
var resp map[string]any
|
|
|
|
|
|
if err := json.Unmarshal(body, &resp); err != nil {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
model, ok := resp["model"].(string)
|
|
|
|
|
|
if !ok || model != fromModel {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
resp["model"] = toModel
|
|
|
|
|
|
newBody, err := json.Marshal(resp)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return newBody
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// OpenAIRecordUsageInput bundles everything RecordUsage needs to persist one
// completed upstream request and bill it to the right party.
type OpenAIRecordUsageInput struct {
	// Result is the outcome of the upstream forward (model, usage, timing).
	Result *OpenAIForwardResult
	// APIKey is the key the client authenticated with; its group determines
	// the rate multiplier and the billing type.
	APIKey *APIKey
	// User owns the API key and is the target of balance deduction.
	User *User
	// Account is the upstream provider account that served the request.
	Account *Account
	// Subscription, when non-nil and the key's group is subscription-typed,
	// switches billing from balance deduction to subscription usage.
	Subscription *UserSubscription
	// UserAgent is the request's User-Agent header (optional; empty = unset).
	UserAgent string
	// IPAddress is the request's client IP address (optional; empty = unset).
	IPAddress string
}
|
|
|
|
|
|
|
// RecordUsage persists a usage log for a completed upstream request and bills
// it either against the user's subscription (when the key's group is
// subscription-typed) or against the user's balance. In simple run mode the
// log is recorded but nothing is billed. Billing side effects are best-effort
// (their errors are deliberately ignored); the function currently always
// returns nil — the error return is kept for interface stability.
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
	result := input.Result
	apiKey := input.APIKey
	user := input.User
	account := input.Account
	subscription := input.Subscription

	// Compute the net new input tokens by subtracting cache-read tokens:
	// input_tokens includes cache_read_tokens, and cached reads must not be
	// charged at the full input price. Clamp at zero for safety.
	actualInputTokens := result.Usage.InputTokens - result.Usage.CacheReadInputTokens
	if actualInputTokens < 0 {
		actualInputTokens = 0
	}

	// Calculate cost
	tokens := UsageTokens{
		InputTokens:         actualInputTokens,
		OutputTokens:        result.Usage.OutputTokens,
		CacheCreationTokens: result.Usage.CacheCreationInputTokens,
		CacheReadTokens:     result.Usage.CacheReadInputTokens,
	}

	// Get rate multiplier: group-level multiplier wins over the default.
	multiplier := s.cfg.Default.RateMultiplier
	if apiKey.GroupID != nil && apiKey.Group != nil {
		multiplier = apiKey.Group.RateMultiplier
	}

	// On pricing failure fall back to a zero-cost breakdown so the usage log
	// is still written (request is effectively free rather than lost).
	cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
	if err != nil {
		cost = &CostBreakdown{ActualCost: 0}
	}

	// Determine billing type
	isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
	billingType := BillingTypeBalance
	if isSubscriptionBilling {
		billingType = BillingTypeSubscription
	}

	// Create usage log
	durationMs := int(result.Duration.Milliseconds())
	accountRateMultiplier := account.BillingRateMultiplier()
	usageLog := &UsageLog{
		UserID:                user.ID,
		APIKeyID:              apiKey.ID,
		AccountID:             account.ID,
		RequestID:             result.RequestID,
		Model:                 result.Model,
		InputTokens:           actualInputTokens,
		OutputTokens:          result.Usage.OutputTokens,
		CacheCreationTokens:   result.Usage.CacheCreationInputTokens,
		CacheReadTokens:       result.Usage.CacheReadInputTokens,
		InputCost:             cost.InputCost,
		OutputCost:            cost.OutputCost,
		CacheCreationCost:     cost.CacheCreationCost,
		CacheReadCost:         cost.CacheReadCost,
		TotalCost:             cost.TotalCost,
		ActualCost:            cost.ActualCost,
		RateMultiplier:        multiplier,
		AccountRateMultiplier: &accountRateMultiplier,
		BillingType:           billingType,
		Stream:                result.Stream,
		DurationMs:            &durationMs,
		FirstTokenMs:          result.FirstTokenMs,
		CreatedAt:             time.Now(),
	}

	// Attach the client User-Agent when present.
	if input.UserAgent != "" {
		usageLog.UserAgent = &input.UserAgent
	}

	// Attach the client IP address when present.
	if input.IPAddress != "" {
		usageLog.IPAddress = &input.IPAddress
	}

	if apiKey.GroupID != nil {
		usageLog.GroupID = apiKey.GroupID
	}
	if subscription != nil {
		usageLog.SubscriptionID = &subscription.ID
	}

	inserted, err := s.usageLogRepo.Create(ctx, usageLog)
	// Simple run mode: record only, never bill.
	if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
		log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
		s.deferredService.ScheduleLastUsedUpdate(account.ID)
		return nil
	}

	// Bill when the log row was inserted, or when the insert errored.
	// NOTE(review): inserted==false with err==nil presumably means a duplicate
	// (idempotent replay) and is the only case billing is skipped; on insert
	// failure billing proceeds anyway so usage is not free — confirm intent.
	shouldBill := inserted || err != nil

	// Deduct based on billing type
	if isSubscriptionBilling {
		if shouldBill && cost.TotalCost > 0 {
			// Best-effort: billing errors are intentionally ignored.
			_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
			// NOTE(review): dereferences apiKey.GroupID — assumed non-nil
			// whenever apiKey.Group is non-nil (required by
			// isSubscriptionBilling); confirm that invariant holds.
			s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
		}
	} else {
		if shouldBill && cost.ActualCost > 0 {
			// Best-effort: billing errors are intentionally ignored.
			_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
			s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
		}
	}

	// Schedule batch update for account last_used_at
	s.deferredService.ScheduleLastUsedUpdate(account.ID)

	return nil
}
|
|
|
|
|
// extractCodexUsageHeaders extracts Codex usage limits from response headers
|
|
|
|
|
|
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
|
|
|
|
|
snapshot := &OpenAICodexUsageSnapshot{}
|
|
|
|
|
|
hasData := false
|
|
|
|
|
|
|
|
|
|
|
|
// Helper to parse float64 from header
|
|
|
|
|
|
parseFloat := func(key string) *float64 {
|
|
|
|
|
|
if v := headers.Get(key); v != "" {
|
|
|
|
|
|
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
|
|
|
|
|
return &f
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Helper to parse int from header
|
|
|
|
|
|
parseInt := func(key string) *int {
|
|
|
|
|
|
if v := headers.Get(key); v != "" {
|
|
|
|
|
|
if i, err := strconv.Atoi(v); err == nil {
|
|
|
|
|
|
return &i
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Primary (weekly) limits
|
|
|
|
|
|
if v := parseFloat("x-codex-primary-used-percent"); v != nil {
|
|
|
|
|
|
snapshot.PrimaryUsedPercent = v
|
|
|
|
|
|
hasData = true
|
|
|
|
|
|
}
|
|
|
|
|
|
if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil {
|
|
|
|
|
|
snapshot.PrimaryResetAfterSeconds = v
|
|
|
|
|
|
hasData = true
|
|
|
|
|
|
}
|
|
|
|
|
|
if v := parseInt("x-codex-primary-window-minutes"); v != nil {
|
|
|
|
|
|
snapshot.PrimaryWindowMinutes = v
|
|
|
|
|
|
hasData = true
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Secondary (5h) limits
|
|
|
|
|
|
if v := parseFloat("x-codex-secondary-used-percent"); v != nil {
|
|
|
|
|
|
snapshot.SecondaryUsedPercent = v
|
|
|
|
|
|
hasData = true
|
|
|
|
|
|
}
|
|
|
|
|
|
if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil {
|
|
|
|
|
|
snapshot.SecondaryResetAfterSeconds = v
|
|
|
|
|
|
hasData = true
|
|
|
|
|
|
}
|
|
|
|
|
|
if v := parseInt("x-codex-secondary-window-minutes"); v != nil {
|
|
|
|
|
|
snapshot.SecondaryWindowMinutes = v
|
|
|
|
|
|
hasData = true
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Overflow ratio
|
|
|
|
|
|
if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil {
|
|
|
|
|
|
snapshot.PrimaryOverSecondaryPercent = v
|
|
|
|
|
|
hasData = true
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if !hasData {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
snapshot.UpdatedAt = time.Now().Format(time.RFC3339)
|
|
|
|
|
|
return snapshot
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// updateCodexUsageSnapshot merges a Codex usage snapshot into the account's
// Extra field. It writes both the raw primary/secondary values and a
// normalized canonical set of 5h/7d fields (OpenAI's primary/secondary naming
// is observed to be unreliable, so the window sizes are used to decide which
// window is which). The database write happens asynchronously on a detached
// 5-second context; the incoming ctx is not used for the write, so the update
// survives request cancellation, and its error is deliberately ignored
// (best-effort telemetry).
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
	if snapshot == nil {
		return
	}

	// Convert snapshot to map for merging into Extra.
	// Raw fields are written verbatim under codex_primary_*/codex_secondary_*.
	updates := make(map[string]any)
	if snapshot.PrimaryUsedPercent != nil {
		updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
	}
	if snapshot.PrimaryResetAfterSeconds != nil {
		updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
	}
	if snapshot.PrimaryWindowMinutes != nil {
		updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes
	}
	if snapshot.SecondaryUsedPercent != nil {
		updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent
	}
	if snapshot.SecondaryResetAfterSeconds != nil {
		updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
	}
	if snapshot.SecondaryWindowMinutes != nil {
		updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes
	}
	if snapshot.PrimaryOverSecondaryPercent != nil {
		updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent
	}
	updates["codex_usage_updated_at"] = snapshot.UpdatedAt

	// Normalize to canonical 5h/7d fields based on window_minutes
	// This fixes the issue where OpenAI's primary/secondary naming is reversed
	// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d

	// IMPORTANT: We can only reliably determine window type from window_minutes field
	// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison

	var primaryWindowMins, secondaryWindowMins int
	var hasPrimaryWindow, hasSecondaryWindow bool

	// Only use window_minutes for reliable window size comparison
	if snapshot.PrimaryWindowMinutes != nil {
		primaryWindowMins = *snapshot.PrimaryWindowMinutes
		hasPrimaryWindow = true
	}

	if snapshot.SecondaryWindowMinutes != nil {
		secondaryWindowMins = *snapshot.SecondaryWindowMinutes
		hasSecondaryWindow = true
	}

	// Determine which is 5h and which is 7d.
	// Exactly one of the 5h flags and one of the 7d flags may end up set;
	// with no window info at all, both fall back to the legacy assumption.
	var use5hFromPrimary, use7dFromPrimary bool
	var use5hFromSecondary, use7dFromSecondary bool

	if hasPrimaryWindow && hasSecondaryWindow {
		// Both window sizes known: compare and assign smaller to 5h, larger to 7d
		if primaryWindowMins < secondaryWindowMins {
			use5hFromPrimary = true
			use7dFromSecondary = true
		} else {
			use5hFromSecondary = true
			use7dFromPrimary = true
		}
	} else if hasPrimaryWindow {
		// Only primary window size known: classify by absolute threshold
		// (360 min = 6h; anything at or below is treated as the 5h window).
		if primaryWindowMins <= 360 {
			use5hFromPrimary = true
		} else {
			use7dFromPrimary = true
		}
	} else if hasSecondaryWindow {
		// Only secondary window size known: classify by absolute threshold
		if secondaryWindowMins <= 360 {
			use5hFromSecondary = true
		} else {
			use7dFromSecondary = true
		}
	} else {
		// No window_minutes available: cannot reliably determine window types
		// Fall back to legacy assumption (may be incorrect)
		// Assume primary=7d, secondary=5h based on historical observation
		if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
			use5hFromSecondary = true
		}
		if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
			use7dFromPrimary = true
		}
	}

	// Write canonical 5h fields
	if use5hFromPrimary {
		if snapshot.PrimaryUsedPercent != nil {
			updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
		}
		if snapshot.PrimaryResetAfterSeconds != nil {
			updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
		}
		if snapshot.PrimaryWindowMinutes != nil {
			updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
		}
	} else if use5hFromSecondary {
		if snapshot.SecondaryUsedPercent != nil {
			updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
		}
		if snapshot.SecondaryResetAfterSeconds != nil {
			updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
		}
		if snapshot.SecondaryWindowMinutes != nil {
			updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
		}
	}

	// Write canonical 7d fields
	if use7dFromPrimary {
		if snapshot.PrimaryUsedPercent != nil {
			updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
		}
		if snapshot.PrimaryResetAfterSeconds != nil {
			updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
		}
		if snapshot.PrimaryWindowMinutes != nil {
			updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
		}
	} else if use7dFromSecondary {
		if snapshot.SecondaryUsedPercent != nil {
			updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
		}
		if snapshot.SecondaryResetAfterSeconds != nil {
			updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
		}
		if snapshot.SecondaryWindowMinutes != nil {
			updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
		}
	}

	// Update account's Extra field asynchronously.
	// NOTE(review): this goroutine is fire-and-forget with no wait handle;
	// a 5s timeout bounds its lifetime, but its result is unobservable.
	go func() {
		updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
	}()
}