mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-27 09:54:47 +08:00
fix: 补齐旧账号的 OpenAI 限流补偿
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -28,6 +29,11 @@ type openAICodexSnapshotAsyncRepo struct {
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
type openAICodexExtraListRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
r.rateLimitCalls = append(r.rateLimitCalls, resetAt)
|
||||
return nil
|
||||
@@ -60,6 +66,22 @@ func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, u
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
_ = platform
|
||||
_ = accountType
|
||||
_ = status
|
||||
_ = search
|
||||
_ = groupID
|
||||
return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -386,6 +408,69 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
|
||||
func ptrFloat64WS(v float64) *float64 { return &v }
|
||||
func ptrIntWS(v int) *int { return &v }
|
||||
|
||||
func TestOpenAIGatewayService_GetSchedulableAccount_ExhaustedCodexExtraSetsRateLimit(t *testing.T) {
|
||||
resetAt := time.Now().Add(6 * 24 * time.Hour)
|
||||
account := Account{
|
||||
ID: 701,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
repo := &openAICodexExtraListRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, rateLimitCh: make(chan time.Time, 1)}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
|
||||
fresh, err := svc.getSchedulableAccount(context.Background(), account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fresh)
|
||||
require.NotNil(t, fresh.RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *fresh.RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待旧快照补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount(t *testing.T) {
|
||||
resetAt := time.Now().Add(4 * 24 * time.Hour)
|
||||
repo := &openAICodexExtraListRepo{
|
||||
stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{{
|
||||
ID: 702,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
}}},
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Len(t, accounts, 1)
|
||||
require.NotNil(t, accounts[0].RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt.UTC(), *accounts[0].RateLimitResetAt, time.Second)
|
||||
select {
|
||||
case persisted := <-repo.rateLimitCh:
|
||||
require.WithinDuration(t, resetAt.UTC(), persisted, time.Second)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("等待列表补写限流状态超时")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) {
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached"))
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", ""))
|
||||
|
||||
Reference in New Issue
Block a user