From 45d57018eb1bc2a37a3aa8d2531d57ec0eb023a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A5=9E=E4=B9=90?= <6682635@qq.com> Date: Sat, 7 Mar 2026 23:59:39 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20OpenAI=20WS=20?= =?UTF-8?q?=E9=99=90=E6=B5=81=E7=8A=B6=E6=80=81=E4=B8=8E=E8=B0=83=E5=BA=A6?= =?UTF-8?q?=E5=90=8C=E6=AD=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service/openai_gateway_service.go | 38 +- .../internal/service/openai_ws_forwarder.go | 48 ++- .../openai_ws_ratelimit_signal_test.go | 392 ++++++++++++++++++ 3 files changed, 471 insertions(+), 7 deletions(-) create mode 100644 backend/internal/service/openai_ws_ratelimit_signal_test.go diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index e642ff60..40a4e377 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -3899,6 +3899,30 @@ func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow return updates } +func codexUsagePercentExhausted(value *float64) bool { + return value != nil && *value >= 100-1e-9 +} + +func codexRateLimitResetAtFromSnapshot(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) *time.Time { + if snapshot == nil { + return nil + } + normalized := snapshot.Normalize() + if normalized == nil { + return nil + } + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) + if codexUsagePercentExhausted(normalized.Used7dPercent) && normalized.Reset7dSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second) + return &resetAt + } + if codexUsagePercentExhausted(normalized.Used5hPercent) && normalized.Reset5hSeconds != nil { + resetAt := baseTime.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second) + return &resetAt + } + return nil +} + // updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { if snapshot == nil { @@ -3908,16 +3932,22 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc return } - updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) - if len(updates) == 0 { + now := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, now) + resetAt := codexRateLimitResetAtFromSnapshot(snapshot, now) + if len(updates) == 0 && resetAt == nil { return } - // Update account's Extra field asynchronously go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + if len(updates) > 0 { + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + } }() } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 36bf8ff8..f2f8edd9 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1853,6 +1853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( wsPath, account.ProxyID != nil && account.Proxy != nil, ) + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error())) + } return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) } defer lease.Release() @@ -2136,6 +2140,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "Upstream websocket error" @@ -2639,6 +2644,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsPath, account.ProxyID != nil && account.Proxy != nil, ) + var dialErr *openAIWSDialError + if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) + } if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { return nil, NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, @@ -2777,6 +2786,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw) fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && @@ -3604,6 +3614,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { errMsg = "OpenAI websocket prewarm error" @@ -3867,6 +3878,36 @@ func classifyOpenAIWSAcquireError(err error) string { return "acquire_conn" } +func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") { + return true + } + if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") { + return true + } + if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") { + return true + } + if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) { + return true + } + return false +} + +func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) { + if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI { + return + } + if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return + } + s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) +} + func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { code := strings.ToLower(strings.TrimSpace(codeRaw)) errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) @@ -3882,6 +3923,9 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri case "previous_response_not_found": return "previous_response_not_found", true } + if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return "upstream_rate_limited", false + } if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { return "upgrade_required", true } @@ -3927,9 +3971,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { case strings.Contains(errType, "permission"), strings.Contains(code, "forbidden"): return http.StatusForbidden - case strings.Contains(errType, "rate_limit"), - strings.Contains(code, "rate_limit"), - strings.Contains(code, "insufficient_quota"): + case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""): return http.StatusTooManyRequests default: return http.StatusBadGateway diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go new file mode 100644 index 00000000..a6b6e874 --- /dev/null +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -0,0 +1,392 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +type openAIWSRateLimitSignalRepo struct { + stubOpenAIAccountRepo + rateLimitCalls []time.Time + updateExtra []map[string]any +} + +type openAICodexSnapshotAsyncRepo struct { + stubOpenAIAccountRepo + updateExtraCh chan map[string]any + rateLimitCh chan time.Time +} + +func (r *openAIWSRateLimitSignalRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + r.rateLimitCalls = append(r.rateLimitCalls, resetAt) + return nil +} + +func (r *openAIWSRateLimitSignalRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtra = append(r.updateExtra, copied) + return nil +} + +func (r *openAICodexSnapshotAsyncRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + +func (r *openAICodexSnapshotAsyncRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + if r.updateExtraCh != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtraCh <- copied + } + return nil +} + +func TestOpenAIGatewayService_Forward_WSv2ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + resetAt := time.Now().Add(2 * time.Hour).Unix() + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { _ = conn.Close() }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "rate_limit_exceeded", + "type": "usage_limit_reached", + "message": "The usage limit has been reached", + "resets_at": resetAt, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)), + }, + } + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + + account := Account{ + ID: 501, + Name: "openai-ws-rate-limit-event", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, &account, body) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Nil(t, upstream.lastReq, "WS 限流 error event 不应回退到同账号 HTTP") + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) +} + +func TestOpenAIGatewayService_Forward_WSv2Handshake429PersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("x-codex-primary-used-percent", "100") + w.Header().Set("x-codex-primary-reset-after-seconds", "7200") + w.Header().Set("x-codex-primary-window-minutes", "10080") + w.Header().Set("x-codex-secondary-used-percent", "3") + w.Header().Set("x-codex-secondary-reset-after-seconds", "1800") + w.Header().Set("x-codex-secondary-window-minutes", "300") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"type":"rate_limit_exceeded","message":"rate limited"}}`)) + })) + defer server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_should_not_run"}`)), + }, + } + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + + account := Account{ + ID: 502, + Name: "openai-ws-rate-limit-handshake", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, &account, body) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Nil(t, upstream.lastReq, "WS 握手 429 不应回退到同账号 HTTP") + require.Len(t, repo.rateLimitCalls, 1) + require.NotEmpty(t, repo.updateExtra, "握手 429 的 x-codex 头应立即落库") + require.Contains(t, repo.updateExtra[0], "codex_usage_updated_at") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageLimitPersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := newOpenAIWSV2TestConfig() + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + resetAt := time.Now().Add(90 * time.Minute).Unix() + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached","resets_at":PLACEHOLDER}}`), + }, + } + captureConn.events[0] = []byte(strings.ReplaceAll(string(captureConn.events[0]), "PLACEHOLDER", strconv.FormatInt(resetAt, 10))) + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + account := Account{ + ID: 503, + Name: "openai-ingress-rate-limit", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + repo := &openAIWSRateLimitSignalRepo{stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}} + rateSvc := &RateLimitService{accountRepo: repo} + svc := &OpenAIGatewayService{ + accountRepo: repo, + rateLimitService: rateSvc, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + cfg: cfg, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + serverErrCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- io.ErrUnexpectedEOF + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, &account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + require.Error(t, serverErr) + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ExhaustedSnapshotSetsRateLimit(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &OpenAIGatewayService{accountRepo: repo} + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(100), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(12), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + before := time.Now() + svc.updateCodexUsageSnapshot(context.Background(), 601, snapshot) + + select { + case updates := <-repo.updateExtraCh: + require.Equal(t, 100.0, updates["codex_7d_used_percent"]) + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 快照落库超时") + } + + select { + case resetAt := <-repo.rateLimitCh: + require.WithinDuration(t, before.Add(time.Hour), resetAt, 2*time.Second) + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 100% 自动切换限流超时") + } +} + +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesNotSetRateLimit(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &OpenAIGatewayService{accountRepo: repo} + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + svc.updateCodexUsageSnapshot(context.Background(), 602, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待 codex 快照落库超时") + } + + select { + case resetAt := <-repo.rateLimitCh: + t.Fatalf("unexpected rate limit reset at: %v", resetAt) + case <-time.After(200 * time.Millisecond): + } +} + +func ptrFloat64WS(v float64) *float64 { return &v } +func ptrIntWS(v int) *int { return &v } + +func TestOpenAIWSErrorHTTPStatusFromRaw_UsageLimitReachedIs429(t *testing.T) { + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("", "usage_limit_reached")) + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatusFromRaw("rate_limit_exceeded", "")) +}