mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
fix: 补齐旧账号的 OpenAI 限流补偿
This commit is contained in:
@@ -359,6 +359,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return usage, nil
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
|
||||
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
|
||||
@@ -1349,6 +1349,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
now := time.Now()
|
||||
for i := range accounts {
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
|
||||
}
|
||||
return accounts, result.Total, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1447,10 +1447,20 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
)
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
account, err = s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
} else {
|
||||
account, err = s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil || account == nil {
|
||||
return account, err
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, time.Now())
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
@@ -3923,6 +3933,45 @@ func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallb
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexRateLimitResetAtFromExtra(extra map[string]any, now time.Time) *time.Time {
|
||||
if len(extra) == 0 {
|
||||
return nil
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(extra, "7d", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
||||
resetAt := progress.ResetsAt.UTC()
|
||||
return &resetAt
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(extra, "5h", now); progress != nil && codexUsagePercentExhausted(&progress.Utilization) && progress.ResetsAt != nil && now.Before(*progress.ResetsAt) {
|
||||
resetAt := progress.ResetsAt.UTC()
|
||||
return &resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyOpenAICodexRateLimitFromExtra(account *Account, now time.Time) (*time.Time, bool) {
|
||||
if account == nil || !account.IsOpenAI() {
|
||||
return nil, false
|
||||
}
|
||||
resetAt := codexRateLimitResetAtFromExtra(account.Extra, now)
|
||||
if resetAt == nil {
|
||||
return nil, false
|
||||
}
|
||||
if account.RateLimitResetAt != nil && now.Before(*account.RateLimitResetAt) && !account.RateLimitResetAt.Before(*resetAt) {
|
||||
return account.RateLimitResetAt, false
|
||||
}
|
||||
account.RateLimitResetAt = resetAt
|
||||
return resetAt, true
|
||||
}
|
||||
|
||||
func syncOpenAICodexRateLimitFromExtra(ctx context.Context, repo AccountRepository, account *Account, now time.Time) *time.Time {
|
||||
resetAt, changed := applyOpenAICodexRateLimitFromExtra(account, now)
|
||||
if !changed || resetAt == nil || repo == nil || account == nil || account.ID <= 0 {
|
||||
return resetAt
|
||||
}
|
||||
_ = repo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
return resetAt
|
||||
}
|
||||
|
||||
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
|
||||
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
|
||||
if snapshot == nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -28,6 +29,11 @@ type openAICodexSnapshotAsyncRepo struct {
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
type openAICodexExtraListRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
r.rateLimitCalls = append(r.rateLimitCalls, resetAt)
|
||||
return nil
|
||||
@@ -60,6 +66,22 @@ func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, u
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
_ = platform
|
||||
_ = accountType
|
||||
_ = status
|
||||
_ = search
|
||||
_ = groupID
|
||||
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -386,6 +408,69 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
|
||||
func ptrFloat64WS(v float64) *float64 { return &v }
|
||||
func ptrIntWS(v int) *int { return &v }
|
||||
|
||||
func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) {
|
||||
resetAt := time.Now().Add(6 * 24 * time.Hour)
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
|
||||
fresh, err := svc.getSchedulableAccount(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fresh)
|
||||
require.NotNil(t, fresh.RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待旧快照补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) {
|
||||
resetAt := time.Now().Add(4 * 24 * time.Hour)
|
||||
repo := &openAICodexExtraListRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{
|
||||
ID: 702,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}}},
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Len(t, accounts, 1)
|
||||
require.NotNil(t, accounts[0].RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待列表补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) {
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached"))
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", ""))
|
||||
|
||||
Reference in New Issue
Block a user