mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-17 13:24:45 +08:00
Merge pull request #842 from pkssssss/fix/openai-ws-usage-refresh
fix: 修复 OpenAI WS 用量窗口刷新与限额状态不同步
This commit is contained in:
@@ -925,6 +925,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1040,6 +1041,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -406,8 +406,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
||||
mergeAccountExtra(account, updates)
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
|
||||
102
backend/internal/service/account_test_service_openai_test.go
Normal file
102
backend/internal/service/account_test_service_openai_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIAccountTestRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updatedExtra map[string]any
|
||||
rateLimitedID int64
|
||||
rateLimitedAt *time.Time
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
r.rateLimitedAt = &resetAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
||||
|
||||
`))
|
||||
resp.Header.Set("x-codex-primary-used-percent", "88")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "42")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 89,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"])
|
||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 88,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.Error(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, int64(88), repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
|
||||
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
}
|
||||
@@ -359,6 +359,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return usage, nil
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
|
||||
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
@@ -367,7 +368,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
|
||||
if (usage.FiveHour == nil || usage.SevenDay == nil) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||
mergeAccountExtra(account, updates)
|
||||
if usage.UpdatedAt == nil {
|
||||
@@ -409,6 +410,40 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if usage == nil {
|
||||
return true
|
||||
}
|
||||
if usage.FiveHour == nil || usage.SevenDay == nil {
|
||||
return true
|
||||
}
|
||||
if account.IsRateLimited() {
|
||||
return true
|
||||
}
|
||||
return isOpenAICodexSnapshotStale(account, now)
|
||||
}
|
||||
|
||||
func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool {
|
||||
if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() {
|
||||
return false
|
||||
}
|
||||
if account.Extra == nil {
|
||||
return true
|
||||
}
|
||||
raw, ok := account.Extra["codex_usage_updated_at"]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
ts, err := parseTime(fmt.Sprint(raw))
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return now.Sub(ts) >= openAIProbeCacheTTL
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return true
|
||||
@@ -478,20 +513,34 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
updates, err := extractOpenAICodexProbeUpdates(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||
if resp == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
68
backend/internal/service/account_usage_service_test.go
Normal file
68
backend/internal/service/account_usage_service_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rateLimitedUntil := time.Now().Add(5 * time.Minute)
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
FiveHour: &UsageProgress{Utilization: 0},
|
||||
SevenDay: &UsageProgress{Utilization: 0},
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) {
|
||||
t.Fatal("expected rate-limited account to force codex snapshot refresh")
|
||||
}
|
||||
|
||||
if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) {
|
||||
t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh")
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) {
|
||||
t.Fatal("expected missing 5h snapshot to require refresh")
|
||||
}
|
||||
|
||||
staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339)
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_enabled": true,
|
||||
"codex_usage_updated_at": staleAt,
|
||||
},
|
||||
}, usage, now) {
|
||||
t.Fatal("expected stale ws snapshot to trigger refresh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||
if err != nil {
|
||||
t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
t.Fatal("expected codex probe updates from 429 headers")
|
||||
}
|
||||
if got := updates["codex_5h_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
|
||||
}
|
||||
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
@@ -1349,6 +1349,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
now := time.Now()
|
||||
for i := range accounts {
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
|
||||
}
|
||||
return accounts, result.Total, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -319,7 +319,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -687,16 +687,20 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, len(candidates), topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
@@ -705,16 +709,23 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
candidate := selectionOrder[0]
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: candidate.account.ID,
|
||||
MaxConcurrency: candidate.account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
for _, candidate := range selectionOrder {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: fresh.ID,
|
||||
MaxConcurrency: fresh.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
|
||||
return nil, len(candidates), topK, loadSkew, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
|
||||
@@ -12,6 +12,78 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAISnapshotCacheStub struct {
|
||||
SchedulerCache
|
||||
snapshotAccounts []*Account
|
||||
accountsByID map[int64]*Account
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
|
||||
if len(s.snapshotAccounts) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
out := make([]*Account, 0, len(s.snapshotAccounts))
|
||||
for _, account := range s.snapshotAccounts {
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
cloned := *account
|
||||
out = append(out, &cloned)
|
||||
}
|
||||
return out, true, nil
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.accountsByID == nil {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.accountsByID[accountID]
|
||||
if account == nil {
|
||||
return nil, nil
|
||||
}
|
||||
cloned := *account
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10101)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
staleSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(31002), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10102)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
stalePrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
|
||||
staleSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshPrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(32002), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(9)
|
||||
|
||||
@@ -1026,7 +1026,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
||||
|
||||
// 3. 按优先级 + LRU 选择最佳账号
|
||||
// Select by priority + LRU
|
||||
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
|
||||
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -1099,7 +1099,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
//
|
||||
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||
// Returns nil if no available account.
|
||||
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
@@ -1111,27 +1111,20 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo
|
||||
continue
|
||||
}
|
||||
|
||||
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
|
||||
if !acc.IsSchedulable() || !acc.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择优先级最高且最久未使用的账号
|
||||
// Select highest priority and least recently used
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
selected = fresh
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterAccount(acc, selected) {
|
||||
selected = acc
|
||||
if s.isBetterAccount(fresh, selected) {
|
||||
selected = fresh
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1309,13 +1302,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, false)
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL)
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
@@ -1359,13 +1356,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
shuffleWithinSortGroups(available)
|
||||
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
@@ -1377,11 +1378,15 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
// ============ Layer 3: Fallback wait ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, false)
|
||||
for _, acc := range candidates {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
AccountID: fresh.ID,
|
||||
MaxConcurrency: fresh.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
@@ -1418,11 +1423,44 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
|
||||
fresh := account
|
||||
if s.schedulerSnapshot != nil {
|
||||
current, err := s.getSchedulableAccount(ctx, account.ID)
|
||||
if err != nil || current == nil {
|
||||
return nil
|
||||
}
|
||||
fresh = current
|
||||
}
|
||||
|
||||
if !fresh.IsSchedulable() || !fresh.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !fresh.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
return fresh
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
)
|
||||
if s.schedulerSnapshot != nil {
|
||||
account, err = s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
} else {
|
||||
account, err = s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
if err != nil || account == nil {
|
||||
return account, err
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now())
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
@@ -3871,6 +3909,69 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow
|
||||
return updates
|
||||
}
|
||||
|
||||
// codexUsagePercentExhausted reports whether value is set and has reached
// 100 percent. A tolerance of 1e-9 below 100 is accepted so that values that
// fall marginally short of 100 still count as exhausted. A nil value is never
// exhausted.
func codexUsagePercentExhausted(value *float64) bool {
	if value == nil {
		return false
	}
	const exhaustedThreshold = 100 - 1e-9
	return *value >= exhaustedThreshold
}
|
||||
|
||||
func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time {
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
return nil
|
||||
}
|
||||
baseTime := codexSnapshotBaseTime(snapshot, fallbackNow)
|
||||
if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil {
|
||||
resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
|
||||
return &resetAt
|
||||
}
|
||||
if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil {
|
||||
resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
|
||||
return &resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time {
|
||||
if len(extra) == 0 {
|
||||
return nil
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
||||
resetAt := progress.ResetsAt.UTC()
|
||||
return &resetAt
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
||||
resetAt := progress.ResetsAt.UTC()
|
||||
return &resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) {
|
||||
if account == nil || !account.IsOpenAI() {
|
||||
return nil, false
|
||||
}
|
||||
resetAt := codexRateLimitResetAtFromExtra(account.Extra, now)
|
||||
if resetAt == nil {
|
||||
return nil, false
|
||||
}
|
||||
if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) {
|
||||
return account.RateLimitResetAt, false
|
||||
}
|
||||
account.RateLimitResetAt = resetAt
|
||||
return resetAt, true
|
||||
}
|
||||
|
||||
func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time {
|
||||
resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now)
|
||||
if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 {
|
||||
return resetAt
|
||||
}
|
||||
_ = repo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
return resetAt
|
||||
}
|
||||
|
||||
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
|
||||
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
|
||||
if snapshot == nil {
|
||||
@@ -3880,16 +3981,22 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
return
|
||||
}
|
||||
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) == 0 {
|
||||
now := time.Now()
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, now)
|
||||
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now)
|
||||
if len(updates) == 0 && resetAt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update account's Extra field asynchronously
|
||||
go func() {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
if len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}
|
||||
if resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,43 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
account := Account{
|
||||
ID: 12,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
RateLimitResetAt: &rateLimitedUntil,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连")
|
||||
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl")
|
||||
require.NoError(t, getErr)
|
||||
require.Zero(t, boundAccountID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
|
||||
@@ -1853,6 +1853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
wsPath,
|
||||
account.ProxyID != nil && account.Proxy != nil,
|
||||
)
|
||||
var dialErr *openAIWSDialError
|
||||
if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
|
||||
}
|
||||
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
|
||||
}
|
||||
defer lease.Release()
|
||||
@@ -2136,6 +2140,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errMsg := strings.TrimSpace(errMsgRaw)
|
||||
if errMsg == "" {
|
||||
errMsg = "Upstream websocket error"
|
||||
@@ -2639,6 +2644,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
wsPath,
|
||||
account.ProxyID != nil && account.Proxy != nil,
|
||||
)
|
||||
var dialErr *openAIWSDialError
|
||||
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
|
||||
}
|
||||
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
|
||||
return nil, NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
@@ -2777,6 +2786,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
|
||||
@@ -3604,6 +3614,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
|
||||
|
||||
if eventType == "error" {
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
errMsg := strings.TrimSpace(errMsgRaw)
|
||||
if errMsg == "" {
|
||||
errMsg = "OpenAI websocket prewarm error"
|
||||
@@ -3798,7 +3809,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
||||
if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() {
|
||||
if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
|
||||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||
return nil, nil
|
||||
}
|
||||
@@ -3867,6 +3878,36 @@ func classifyOpenAIWSAcquireError(err error) string {
|
||||
return "acquire_conn"
|
||||
}
|
||||
|
||||
// isOpenAIWSRateLimitError reports whether the raw fields of an upstream error
// describe a rate-limit or usage-limit condition.
//
// All three inputs are trimmed and lower-cased first. The result is true when:
//   - the error type contains "rate_limit" or "usage_limit";
//   - the code contains "rate_limit", "usage_limit" or "insufficient_quota";
//   - the message contains "usage limit" together with "reached";
//   - the message contains "rate limit" together with "reached" or "exceeded".
func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool {
	normalize := func(raw string) string {
		return strings.ToLower(strings.TrimSpace(raw))
	}
	containsAny := func(haystack string, needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(haystack, needle) {
				return true
			}
		}
		return false
	}

	code, errType, msg := normalize(codeRaw), normalize(errTypeRaw), normalize(msgRaw)

	switch {
	case containsAny(errType, "rate_limit", "usage_limit"):
		return true
	case containsAny(code, "rate_limit", "usage_limit", "insufficient_quota"):
		return true
	case strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached"):
		return true
	case strings.Contains(msg, "rate limit") && containsAny(msg, "reached", "exceeded"):
		return true
	default:
		return false
	}
}
|
||||
|
||||
func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) {
|
||||
if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI {
|
||||
return
|
||||
}
|
||||
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
||||
}
|
||||
|
||||
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
|
||||
code := strings.ToLower(strings.TrimSpace(codeRaw))
|
||||
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
|
||||
@@ -3882,6 +3923,9 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
|
||||
case "previous_response_not_found":
|
||||
return "previous_response_not_found", true
|
||||
}
|
||||
if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||
return "upstream_rate_limited", false
|
||||
}
|
||||
if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
|
||||
return "upgrade_required", true
|
||||
}
|
||||
@@ -3927,9 +3971,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
|
||||
case strings.Contains(errType, "permission"),
|
||||
strings.Contains(code, "forbidden"):
|
||||
return http.StatusForbidden
|
||||
case strings.Contains(errType, "rate_limit"),
|
||||
strings.Contains(code, "rate_limit"),
|
||||
strings.Contains(code, "insufficient_quota"):
|
||||
case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""):
|
||||
return http.StatusTooManyRequests
|
||||
default:
|
||||
return http.StatusBadGateway
|
||||
|
||||
477
backend/internal/service/openai_ws_ratelimit_signal_test.go
Normal file
477
backend/internal/service/openai_ws_ratelimit_signal_test.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIWSRateLimitSignalRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCalls []time.Time
|
||||
updateExtra []map[string]any
|
||||
}
|
||||
|
||||
type openAICodexSnapshotAsyncRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
updateExtraCh chan map[string]any
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
type openAICodexExtraListRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
r.rateLimitCalls = append(r.rateLimitCalls, resetAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtra = append(r.updateExtra, copied)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
if r.updateExtraCh != nil {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtraCh <- copied
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
_ = platform
|
||||
_ = accountType
|
||||
_ = status
|
||||
_ = search
|
||||
_ = groupID
|
||||
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
resetAt := time.Now().Add(2 * time.Hour).Unix()
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "rate_limit_exceeded",
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached",
|
||||
"resets_at": resetAt,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
|
||||
account := Account{
|
||||
ID: 501,
|
||||
Name: "openai-ws-rate-limit-event",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: upstream,
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, &account, body)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP")
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("x-codex-primary-used-percent", "100")
|
||||
w.Header().Set("x-codex-primary-reset-after-seconds", "7200")
|
||||
w.Header().Set("x-codex-primary-window-minutes", "10080")
|
||||
w.Header().Set("x-codex-secondary-used-percent", "3")
|
||||
w.Header().Set("x-codex-secondary-reset-after-seconds", "1800")
|
||||
w.Header().Set("x-codex-secondary-window-minutes", "300")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"type":"rate_limit_exceeded","message":"rate limited"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
|
||||
account := Account{
|
||||
ID: 502,
|
||||
Name: "openai-ws-rate-limit-handshake",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": server.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: upstream,
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, &account, body)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP")
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库")
|
||||
require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
resetAt := time.Now().Add(90 * time.Minute).Unix()
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}}`),
|
||||
},
|
||||
}
|
||||
captureConn.events[0] = []byte(strings.ReplaceAll(string(captureConn.events[0]), "PLACEHOLDER", strconv.FormatInt(resetAt, 10)))
|
||||
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
account := Account{
|
||||
ID: 503,
|
||||
Name: "openai-ingress-rate-limit",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}}
|
||||
rateSvc := &RateLimitService{accountRepo: repo}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
rateLimitService: rateSvc,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.CloseNow() }()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, &account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = clientConn.CloseNow() }()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.Error(t, serverErr)
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) {
|
||||
repo := &openAICodexSnapshotAsyncRepo{
|
||||
updateExtraCh: make(chan map[string]any, 1),
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: ptrFloat64WS(100),
|
||||
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||
SecondaryUsedPercent: ptrFloat64WS(12),
|
||||
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||
SecondaryWindowMinutes: ptrIntWS(300),
|
||||
}
|
||||
before := time.Now()
|
||||
svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot)
|
||||
|
||||
select {
|
||||
case updates := <-repo.updateExtraCh:
|
||||
require.Equal(t, 100.0, updates["codex_7d_used_percent"])
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 快照落库超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case resetAt := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 100% 自动切换限流超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) {
|
||||
repo := &openAICodexSnapshotAsyncRepo{
|
||||
updateExtraCh: make(chan map[string]any, 1),
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: ptrFloat64WS(94),
|
||||
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||
SecondaryUsedPercent: ptrFloat64WS(22),
|
||||
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||
SecondaryWindowMinutes: ptrIntWS(300),
|
||||
}
|
||||
svc.updateCodexUsageSnapshot(context.Background(), 602, snapshot)
|
||||
|
||||
select {
|
||||
case <-repo.updateExtraCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待 codex 快照落库超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case resetAt := <-repo.rateLimitCh:
|
||||
t.Fatalf("unexpected rate limit reset at: %v", resetAt)
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
// ptrFloat64WS returns a pointer to a fresh copy of v.
func ptrFloat64WS(v float64) *float64 {
	value := v
	return &value
}
|
||||
// ptrIntWS returns a pointer to a fresh copy of v.
func ptrIntWS(v int) *int {
	value := v
	return &value
}
|
||||
|
||||
func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) {
|
||||
resetAt := time.Now().Add(6 * 24 * time.Hour)
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
|
||||
fresh, err := svc.getSchedulableAccount(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fresh)
|
||||
require.NotNil(t, fresh.RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待旧快照补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) {
|
||||
resetAt := time.Now().Add(4 * 24 * time.Hour)
|
||||
repo := &openAICodexExtraListRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{
|
||||
ID: 702,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}}},
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Len(t, accounts, 1)
|
||||
require.NotNil(t, accounts[0].RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待列表补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) {
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached"))
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", ""))
|
||||
}
|
||||
@@ -615,6 +615,7 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
|
||||
if account.Platform == PlatformOpenAI {
|
||||
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
@@ -878,6 +879,23 @@ func pickSooner(a, b *time.Time) *time.Time {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RateLimitService) persistOpenAICodexSnapshot(ctx context.Context, account *Account, headers http.Header) {
|
||||
if s == nil || s.accountRepo == nil || account == nil || headers == nil {
|
||||
return
|
||||
}
|
||||
snapshot := ParseCodexRateLimitHeaders(headers)
|
||||
if snapshot == nil {
|
||||
return
|
||||
}
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, updates); err != nil {
|
||||
slog.Warn("openai_codex_snapshot_persist_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
// OpenAI 的 usage_limit_reached 错误格式:
|
||||
//
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -143,6 +144,51 @@ func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type openAI429SnapshotRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
rateLimitedID int64
|
||||
updatedExtra map[string]any
|
||||
}
|
||||
|
||||
func (r *openAI429SnapshotRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAI429SnapshotRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandle429_OpenAIPersistsCodexSnapshotImmediately(t *testing.T) {
|
||||
repo := &openAI429SnapshotRepo{}
|
||||
svc := NewRateLimitService(repo, nil, nil, nil, nil)
|
||||
account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
svc.handle429(context.Background(), account, headers, nil)
|
||||
|
||||
if repo.rateLimitedID != account.ID {
|
||||
t.Fatalf("rateLimitedID = %d, want %d", repo.rateLimitedID, account.ID)
|
||||
}
|
||||
if len(repo.updatedExtra) == 0 {
|
||||
t.Fatal("expected codex snapshot to be persisted on 429")
|
||||
}
|
||||
if got := repo.updatedExtra["codex_5h_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
|
||||
}
|
||||
if got := repo.updatedExtra["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits(t *testing.T) {
|
||||
// Test the Normalize() method directly
|
||||
pUsed := 100.0
|
||||
|
||||
Reference in New Issue
Block a user