mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
Merge pull request #931 from xvhuan/fix/db-write-amplification-20260311
降低 quota 与 Codex 快照热路径的数据库写放大
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
@@ -1185,12 +1186,117 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
} else if shouldSyncSchedulerSnapshotForExtraUpdates(updates) {
|
||||||
|
// codex 限流快照仍需要让调度缓存尽快看见,避免 DB 抖动时丢失自愈链路。
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for key := range updates {
|
||||||
|
if isSchedulerNeutralAccountExtraKey(key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldSyncSchedulerSnapshotForExtraUpdates(updates map[string]any) bool {
|
||||||
|
return codexExtraIndicatesRateLimit(updates, "7d") || codexExtraIndicatesRateLimit(updates, "5h")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSchedulerNeutralAccountExtraKey(key string) bool {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if key == "session_window_utilization" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return strings.HasPrefix(key, "codex_")
|
||||||
|
}
|
||||||
|
|
||||||
|
func codexExtraIndicatesRateLimit(updates map[string]any, window string) bool {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
usedValue, ok := updates["codex_"+window+"_used_percent"]
|
||||||
|
if !ok || !extraValueIndicatesExhausted(usedValue) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return extraValueHasResetMarker(updates["codex_"+window+"_reset_at"]) ||
|
||||||
|
extraValueHasPositiveNumber(updates["codex_"+window+"_reset_after_seconds"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func extraValueIndicatesExhausted(value any) bool {
|
||||||
|
number, ok := extraValueToFloat64(value)
|
||||||
|
return ok && number >= 100-1e-9
|
||||||
|
}
|
||||||
|
|
||||||
|
func extraValueHasPositiveNumber(value any) bool {
|
||||||
|
number, ok := extraValueToFloat64(value)
|
||||||
|
return ok && number > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func extraValueHasResetMarker(value any) bool {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v) != ""
|
||||||
|
case time.Time:
|
||||||
|
return !v.IsZero()
|
||||||
|
case *time.Time:
|
||||||
|
return v != nil && !v.IsZero()
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extraValueToFloat64(value any) (float64, bool) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case float64:
|
||||||
|
return v, true
|
||||||
|
case float32:
|
||||||
|
return float64(v), true
|
||||||
|
case int:
|
||||||
|
return float64(v), true
|
||||||
|
case int8:
|
||||||
|
return float64(v), true
|
||||||
|
case int16:
|
||||||
|
return float64(v), true
|
||||||
|
case int32:
|
||||||
|
return float64(v), true
|
||||||
|
case int64:
|
||||||
|
return float64(v), true
|
||||||
|
case uint:
|
||||||
|
return float64(v), true
|
||||||
|
case uint8:
|
||||||
|
return float64(v), true
|
||||||
|
case uint16:
|
||||||
|
return float64(v), true
|
||||||
|
case uint32:
|
||||||
|
return float64(v), true
|
||||||
|
case uint64:
|
||||||
|
return float64(v), true
|
||||||
|
case json.Number:
|
||||||
|
parsed, err := v.Float64()
|
||||||
|
return parsed, err == nil
|
||||||
|
case string:
|
||||||
|
parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
|
||||||
|
return parsed, err == nil
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||||
if len(ids) == 0 {
|
if len(ids) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|||||||
@@ -623,6 +623,65 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
|||||||
s.Require().Equal("val", got.Extra["key"])
|
s.Require().Equal("val", got.Extra["key"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralKeysSkipOutbox() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-neutral", Extra: map[string]any{}})
|
||||||
|
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||||
|
"codex_usage_updated_at": "2026-03-11T13:00:00Z",
|
||||||
|
"codex_5h_used_percent": 12.5,
|
||||||
|
"session_window_utilization": 0.42,
|
||||||
|
}))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(0, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "acc-extra-codex-exhausted",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{},
|
||||||
|
})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||||
|
"codex_7d_used_percent": 100.0,
|
||||||
|
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
|
||||||
|
"codex_7d_reset_after_seconds": 86400,
|
||||||
|
}))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(0, count)
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||||
|
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_CustomKeysStillEnqueueOutbox() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-custom", Extra: map[string]any{}})
|
||||||
|
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||||
|
"custom_scheduler_sensitive_key": true,
|
||||||
|
}))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(1, count)
|
||||||
|
}
|
||||||
|
|
||||||
// --- GetByCRSAccountID ---
|
// --- GetByCRSAccountID ---
|
||||||
|
|
||||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||||
|
|||||||
@@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
|
|||||||
return updated.QuotaUsed, nil
|
return updated.QuotaUsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
|
||||||
|
// as quota_exhausted, and returns the latest quota state in one round trip.
|
||||||
|
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
|
||||||
|
query := `
|
||||||
|
UPDATE api_keys
|
||||||
|
SET
|
||||||
|
quota_used = quota_used + $1,
|
||||||
|
status = CASE
|
||||||
|
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
|
||||||
|
ELSE status
|
||||||
|
END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $3 AND deleted_at IS NULL
|
||||||
|
RETURNING quota_used, quota, key, status
|
||||||
|
`
|
||||||
|
|
||||||
|
state := &service.APIKeyQuotaUsageState{}
|
||||||
|
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||||
affected, err := r.client.APIKey.Update().
|
affected, err := r.client.APIKey.Update().
|
||||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||||
|
|||||||
@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
|||||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
|
||||||
|
user := s.mustCreateUser("quota-state@test.com")
|
||||||
|
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
|
||||||
|
key.Quota = 3
|
||||||
|
key.QuotaUsed = 1
|
||||||
|
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
|
||||||
|
|
||||||
|
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
|
||||||
|
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
|
||||||
|
s.Require().NotNil(state)
|
||||||
|
s.Require().Equal(3.5, state.QuotaUsed)
|
||||||
|
s.Require().Equal(3.0, state.Quota)
|
||||||
|
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
|
||||||
|
s.Require().Equal(key.Key, state.Key)
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal(3.5, got.QuotaUsed)
|
||||||
|
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
|
||||||
|
}
|
||||||
|
|
||||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||||
|
|||||||
@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
|
||||||
mergeAccountExtra(account, updates)
|
mergeAccountExtra(account, updates)
|
||||||
|
if resetAt != nil {
|
||||||
|
account.RateLimitResetAt = resetAt
|
||||||
|
}
|
||||||
if usage.UpdatedAt == nil {
|
if usage.UpdatedAt == nil {
|
||||||
usage.UpdatedAt = &now
|
usage.UpdatedAt = &now
|
||||||
}
|
}
|
||||||
@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
|
||||||
if account == nil || !account.IsOAuth() {
|
if account == nil || !account.IsOAuth() {
|
||||||
return nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
accessToken := account.GetOpenAIAccessToken()
|
accessToken := account.GetOpenAIAccessToken()
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("no access token available")
|
return nil, nil, fmt.Errorf("no access token available")
|
||||||
}
|
}
|
||||||
modelID := openaipkg.DefaultTestModel
|
modelID := openaipkg.DefaultTestModel
|
||||||
payload := createOpenAITestPayload(modelID, true)
|
payload := createOpenAITestPayload(modelID, true)
|
||||||
payloadBytes, err := json.Marshal(payload)
|
payloadBytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create openai probe request: %w", err)
|
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
|
||||||
}
|
}
|
||||||
req.Host = "chatgpt.com"
|
req.Host = "chatgpt.com"
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
|||||||
ResponseHeaderTimeout: 10 * time.Second,
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("build openai probe client: %w", err)
|
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
updates, err := extractOpenAICodexProbeUpdates(resp)
|
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if len(updates) > 0 {
|
if len(updates) > 0 || resetAt != nil {
|
||||||
go func(accountID int64, updates map[string]any) {
|
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
|
||||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
return updates, resetAt, nil
|
||||||
defer updateCancel()
|
}
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
|
||||||
|
if s == nil || s.accountRepo == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(updates) == 0 && resetAt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer updateCancel()
|
||||||
|
if len(updates) > 0 {
|
||||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||||
}(account.ID, updates)
|
}
|
||||||
return updates, nil
|
if resetAt != nil {
|
||||||
|
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
|
||||||
|
if resp == nil {
|
||||||
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
return nil, nil
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||||
|
baseTime := time.Now()
|
||||||
|
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
|
||||||
|
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
|
||||||
|
if len(updates) > 0 {
|
||||||
|
return updates, resetAt, nil
|
||||||
|
}
|
||||||
|
return nil, resetAt, nil
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||||
if resp == nil {
|
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
|
||||||
return nil, nil
|
return updates, err
|
||||||
}
|
|
||||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
|
||||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
|
||||||
if len(updates) > 0 {
|
|
||||||
return updates, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||||
|
|||||||
@@ -1,11 +1,36 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type accountUsageCodexProbeRepo struct {
|
||||||
|
stubOpenAIAccountRepo
|
||||||
|
updateExtraCh chan map[string]any
|
||||||
|
rateLimitCh chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||||
|
if r.updateExtraCh != nil {
|
||||||
|
copied := make(map[string]any, len(updates))
|
||||||
|
for k, v := range updates {
|
||||||
|
copied[k] = v
|
||||||
|
}
|
||||||
|
r.updateExtraCh <- copied
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||||
|
if r.rateLimitCh != nil {
|
||||||
|
r.rateLimitCh <- resetAt
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
|
|||||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("x-codex-primary-used-percent", "100")
|
||||||
|
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||||
|
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||||
|
headers.Set("x-codex-secondary-used-percent", "100")
|
||||||
|
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||||
|
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||||
|
|
||||||
|
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(updates) == 0 {
|
||||||
|
t.Fatal("expected codex probe updates from 429 headers")
|
||||||
|
}
|
||||||
|
if resetAt == nil {
|
||||||
|
t.Fatal("expected resetAt from exhausted codex headers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
repo := &accountUsageCodexProbeRepo{
|
||||||
|
updateExtraCh: make(chan map[string]any, 1),
|
||||||
|
rateLimitCh: make(chan time.Time, 1),
|
||||||
|
}
|
||||||
|
svc := &AccountUsageService{accountRepo: repo}
|
||||||
|
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
||||||
|
|
||||||
|
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
|
||||||
|
"codex_7d_used_percent": 100.0,
|
||||||
|
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
|
||||||
|
}, &resetAt)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case updates := <-repo.updateExtraCh:
|
||||||
|
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||||
|
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("waiting for codex probe extra persistence timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-repo.rateLimitCh:
|
||||||
|
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
|
||||||
|
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
|
|||||||
return d.Usage7d
|
return d.Usage7d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
|
||||||
|
// It is intentionally small so repositories can return it from a single SQL statement.
|
||||||
|
type APIKeyQuotaUsageState struct {
|
||||||
|
QuotaUsed float64
|
||||||
|
Quota float64
|
||||||
|
Key string
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
// APIKeyCache defines cache operations for API key service
|
// APIKeyCache defines cache operations for API key service
|
||||||
type APIKeyCache interface {
|
type APIKeyCache interface {
|
||||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||||
@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type quotaStateReader interface {
|
||||||
|
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
|
||||||
|
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("increment quota used: %w", err)
|
||||||
|
}
|
||||||
|
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, state.Key)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Use repository to atomically increment quota_used
|
// Use repository to atomically increment quota_used
|
||||||
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
170
backend/internal/service/api_key_service_quota_test.go
Normal file
170
backend/internal/service/api_key_service_quota_test.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type quotaStateRepoStub struct {
|
||||||
|
quotaBaseAPIKeyRepoStub
|
||||||
|
stateCalls int
|
||||||
|
state *APIKeyQuotaUsageState
|
||||||
|
stateErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
|
||||||
|
s.stateCalls++
|
||||||
|
if s.stateErr != nil {
|
||||||
|
return nil, s.stateErr
|
||||||
|
}
|
||||||
|
if s.state == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
out := *s.state
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type quotaStateCacheStub struct {
|
||||||
|
deleteAuthKeys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type quotaBaseAPIKeyRepoStub struct {
|
||||||
|
getByIDCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
|
||||||
|
panic("unexpected Create call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
|
||||||
|
s.getByIDCalls++
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
|
||||||
|
panic("unexpected GetKeyAndOwnerID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKey call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
|
||||||
|
panic("unexpected Update call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||||
|
panic("unexpected VerifyOwnership call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected CountByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByKey call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
|
||||||
|
panic("unexpected SearchAPIKeys call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected ClearGroupIDByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
panic("unexpected CountByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
|
||||||
|
panic("unexpected IncrementQuotaUsed call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||||||
|
panic("unexpected UpdateLastUsed call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
|
||||||
|
panic("unexpected IncrementRateLimitUsage call")
|
||||||
|
}
|
||||||
|
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
|
||||||
|
panic("unexpected ResetRateLimitWindows call")
|
||||||
|
}
|
||||||
|
// GetRateLimitData is an interface-satisfying stub; it panics to surface any
// unexpected use from the quota test paths.
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
	panic("unexpected GetRateLimitData call")
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
|
||||||
|
repo := "aStateRepoStub{
|
||||||
|
state: &APIKeyQuotaUsageState{
|
||||||
|
QuotaUsed: 12,
|
||||||
|
Quota: 10,
|
||||||
|
Key: "sk-test-quota",
|
||||||
|
Status: StatusAPIKeyQuotaExhausted,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := "aStateCacheStub{}
|
||||||
|
svc := &APIKeyService{
|
||||||
|
apiKeyRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.stateCalls)
|
||||||
|
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
|
||||||
|
}
|
||||||
@@ -52,6 +52,8 @@ const (
|
|||||||
openAIWSRetryJitterRatioDefault = 0.2
|
openAIWSRetryJitterRatioDefault = 0.2
|
||||||
openAICompactSessionSeedKey = "openai_compact_session_seed"
|
openAICompactSessionSeedKey = "openai_compact_session_seed"
|
||||||
codexCLIVersion = "0.104.0"
|
codexCLIVersion = "0.104.0"
|
||||||
|
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
|
||||||
|
openAICodexSnapshotPersistMinInterval = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAI allowed headers whitelist (for non-passthrough).
|
// OpenAI allowed headers whitelist (for non-passthrough).
|
||||||
@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct {
|
|||||||
nonRetryableFastFallback atomic.Int64
|
nonRetryableFastFallback atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// accountWriteThrottle rate-limits per-account persistence writes: for a given
// account id, at most one write is allowed per minInterval (see Allow). It is
// used to damp database write amplification on hot accounts.
type accountWriteThrottle struct {
	minInterval time.Duration // minimum gap between two allowed writes for the same account

	mu       sync.Mutex          // guards lastByID
	lastByID map[int64]time.Time // account id -> time of the last allowed write
}
|
||||||
|
|
||||||
|
func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle {
|
||||||
|
return &accountWriteThrottle{
|
||||||
|
minInterval: minInterval,
|
||||||
|
lastByID: make(map[int64]time.Time),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
|
||||||
|
if t == nil || id <= 0 || t.minInterval <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.lastByID[id] = now
|
||||||
|
|
||||||
|
if len(t.lastByID) > 4096 {
|
||||||
|
cutoff := now.Add(-4 * t.minInterval)
|
||||||
|
for accountID, writtenAt := range t.lastByID {
|
||||||
|
if writtenAt.Before(cutoff) {
|
||||||
|
delete(t.lastByID, accountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultOpenAICodexSnapshotPersistThrottle is the shared fallback throttle
// used when a service instance has no codexSnapshotThrottle of its own.
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
|
||||||
|
|
||||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||||
type OpenAIGatewayService struct {
|
type OpenAIGatewayService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
@@ -289,6 +331,7 @@ type OpenAIGatewayService struct {
|
|||||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||||
|
codexSnapshotThrottle *accountWriteThrottle
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||||
@@ -329,17 +372,25 @@ func NewOpenAIGatewayService(
|
|||||||
nil,
|
nil,
|
||||||
"service.openai_gateway",
|
"service.openai_gateway",
|
||||||
),
|
),
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
deferredService: deferredService,
|
deferredService: deferredService,
|
||||||
openAITokenProvider: openAITokenProvider,
|
openAITokenProvider: openAITokenProvider,
|
||||||
toolCorrector: NewCodexToolCorrector(),
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||||
}
|
}
|
||||||
svc.logOpenAIWSModeBootstrap()
|
svc.logOpenAIWSModeBootstrap()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
||||||
|
if s != nil && s.codexSnapshotThrottle != nil {
|
||||||
|
return s.codexSnapshotThrottle
|
||||||
|
}
|
||||||
|
return defaultOpenAICodexSnapshotPersistThrottle
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
||||||
return &billingDeps{
|
return &billingDeps{
|
||||||
accountRepo: s.accountRepo,
|
accountRepo: s.accountRepo,
|
||||||
@@ -4164,11 +4215,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
|||||||
if len(updates) == 0 && resetAt == nil {
|
if len(updates) == 0 && resetAt == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
|
||||||
|
if !shouldPersistUpdates && resetAt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if len(updates) > 0 {
|
if shouldPersistUpdates {
|
||||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||||
}
|
}
|
||||||
if resetAt != nil {
|
if resetAt != nil {
|
||||||
|
|||||||
@@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) {
|
||||||
|
repo := &openAICodexSnapshotAsyncRepo{
|
||||||
|
updateExtraCh: make(chan map[string]any, 2),
|
||||||
|
rateLimitCh: make(chan time.Time, 2),
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
codexSnapshotThrottle: newAccountWriteThrottle(time.Hour),
|
||||||
|
}
|
||||||
|
snapshot := &OpenAICodexUsageSnapshot{
|
||||||
|
PrimaryUsedPercent: ptrFloat64WS(94),
|
||||||
|
PrimaryResetAfterSeconds: ptrIntWS(3600),
|
||||||
|
PrimaryWindowMinutes: ptrIntWS(10080),
|
||||||
|
SecondaryUsedPercent: ptrFloat64WS(22),
|
||||||
|
SecondaryResetAfterSeconds: ptrIntWS(1200),
|
||||||
|
SecondaryWindowMinutes: ptrIntWS(300),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
|
||||||
|
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-repo.updateExtraCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("等待第一次 codex 快照落库超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case updates := <-repo.updateExtraCh:
|
||||||
|
t.Fatalf("unexpected second codex snapshot write: %v", updates)
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ptrFloat64WS returns a pointer to a copy of v (test helper for optional
// float64 snapshot fields).
func ptrFloat64WS(v float64) *float64 {
	out := v
	return &out
}
||||||
// ptrIntWS returns a pointer to a copy of v (test helper for optional int
// snapshot fields).
func ptrIntWS(v int) *int {
	out := v
	return &out
}
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user