diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 9bb3aa0b..7c001118 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -359,6 +359,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou if account == nil { return usage, nil } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now) if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil { usage.FiveHour = progress diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 680268e0..ab00932c 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1349,6 +1349,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, if err != nil { return nil, 0, err } + now := time.Now() + for i := range accounts { + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now) + } return accounts, result.Total, nil } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 40a4e377..709ee808 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1447,10 +1447,20 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. } func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { + var ( + account *Account + err error + ) if s.schedulerSnapshot != nil { - return s.schedulerSnapshot.GetAccount(ctx, accountID) + account, err = s.schedulerSnapshot.GetAccount(ctx, accountID) + } else { + account, err = s.accountRepo.GetByID(ctx, accountID) } - return s.accountRepo.GetByID(ctx, accountID) + if err != nil || account == nil { + return account, err + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now()) + return account, nil } func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -3923,6 +3933,45 @@ func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallb return nil } +func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time { + if len(extra) == 0 { + return nil + } + if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) { + resetAt := progress.ResetsAt.UTC() + return &resetAt + } + return nil +} + +func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) { + if account == nil || !account.IsOpenAI() { + return nil, false + } + resetAt := codexRateLimitResetAtFromExtra(account.Extra, now) + if resetAt == nil { + return nil, false + } + if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) { + return account.RateLimitResetAt, false + } + account.RateLimitResetAt = resetAt + return resetAt, true +} + +func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time { + resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now) + if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 { + return resetAt + } + _ = repo.SetRateLimited(ctx, account.ID, *resetAt) + return resetAt +} + // updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { if snapshot == nil { diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index a6b6e874..28cb8e00 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -28,6 +29,11 @@ type openAICodexSnapshotAsyncRepo struct { rateLimitCh chan time.Time } +type openAICodexExtraListRepo struct { + stubOpenAIAccountRepo + rateLimitCh chan time.Time +} + func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { r.rateLimitCalls = append(r.rateLimitCalls, resetAt) return nil @@ -60,6 +66,22 @@ func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, u return nil } +func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { + _ = platform + _ = accountType + _ = status + _ = search + _ = groupID + return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil +} + func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { gin.SetMode(gin.TestMode) @@ -386,6 +408,69 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN func ptrFloat64WS(v float64) *float64 { return &v } func ptrIntWS(v int) *int { return &v } +func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) { + resetAt := time.Now().Add(6 * 24 * time.Hour) + account := Account{ + ID: 701, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339), + }, + } + repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)} + svc := &OpenAIGatewayService{accountRepo: repo} + + fresh, err := svc.getSchedulableAccount(context.Background(), account.ID) + require.NoError(t, err) + require.NotNil(t, fresh) + require.NotNil(t, fresh.RateLimitResetAt) + require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second) + select { + case persisted := <-repo.rateLimitCh: + require.WithinDuration(t, resetAt.UTC(), persisted, time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待旧快照补写限流状态超时") + } +} + +func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) { + resetAt := time.Now().Add(4 * 24 * time.Hour) + repo := &openAICodexExtraListRepo{ + stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{ + ID: 702, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339), + }, + }}}, + rateLimitCh: make(chan time.Time, 1), + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Len(t, accounts, 1) + require.NotNil(t, accounts[0].RateLimitResetAt) + require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second) + select { + case persisted := <-repo.rateLimitCh: + require.WithinDuration(t, resetAt.UTC(), persisted, time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待列表补写限流状态超时") + } +} + func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) { require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached")) require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", ""))