Merge remote-tracking branch 'origin/main' into feat/billing-ledger-decouple-usage-log-20260312

This commit is contained in:
ius
2026-03-12 16:53:28 +08:00
76 changed files with 6148 additions and 404 deletions

View File

@@ -45,16 +45,23 @@ const (
// TestEvent represents a SSE event for account testing
type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
ImageURL string `json:"image_url,omitempty"`
MimeType string `json:"mime_type,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
// Fallback prompts used by createGeminiTestPayload when the caller's prompt is
// empty after trimming whitespace.
const (
// defaultGeminiTextTestPrompt is the fallback for non-image Gemini models.
defaultGeminiTextTestPrompt = "hi"
// defaultGeminiImageTestPrompt is the fallback for image-generation models.
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
)
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo AccountRepository
@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
ctx := c.Request.Context()
// Get account
@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
if account.IsGemini() {
return s.testGeminiAccountConnection(c, account, modelID)
return s.testGeminiAccountConnection(c, account, modelID, prompt)
}
if account.Platform == PlatformAntigravity {
return s.routeAntigravityTest(c, account, modelID)
return s.routeAntigravityTest(c, account, modelID, prompt)
}
if account.Platform == PlatformSora {
@@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
// Determine the model to use
@@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
c.Writer.Flush()
// Create test payload (Gemini format)
payload := createGeminiTestPayload()
payload := createGeminiTestPayload(testModelID, prompt)
// Build request based on account type
var req *http.Request
@@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string {
// routeAntigravityTest routes test requests for Antigravity accounts.
// APIKey accounts use the native protocol (consistent with gateway_handler routing); OAuth/Upstream accounts go through the CRS relay.
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
if account.Type == AccountTypeAPIKey {
if strings.HasPrefix(modelID, "gemini-") {
return s.testGeminiAccountConnection(c, account, modelID)
return s.testGeminiAccountConnection(c, account, modelID, prompt)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
@@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
return req, nil
}
// createGeminiTestPayload creates a minimal test payload for Gemini API
func createGeminiTestPayload() []byte {
// createGeminiTestPayload creates a minimal test payload for Gemini API.
// Image models use the image-generation path so the frontend can preview the returned image.
func createGeminiTestPayload(modelID string, prompt string) []byte {
if isImageGenerationModel(modelID) {
imagePrompt := strings.TrimSpace(prompt)
if imagePrompt == "" {
imagePrompt = defaultGeminiImageTestPrompt
}
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": imagePrompt},
},
},
},
"generationConfig": map[string]any{
"responseModalities": []string{"TEXT", "IMAGE"},
"imageConfig": map[string]any{
"aspectRatio": "1:1",
},
},
}
bytes, _ := json.Marshal(payload)
return bytes
}
textPrompt := strings.TrimSpace(prompt)
if textPrompt == "" {
textPrompt = defaultGeminiTextTestPrompt
}
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": "hi"},
{"text": textPrompt},
},
},
},
@@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
if text, ok := partMap["text"].(string); ok && text != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
if inlineData, ok := partMap["inlineData"].(map[string]any); ok {
mimeType, _ := inlineData["mimeType"].(string)
data, _ := inlineData["data"].(string)
if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" {
s.sendEvent(c, TestEvent{
Type: "image",
ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data),
MimeType: mimeType,
})
}
}
}
}
}
@@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
finishedAt := time.Now()
body := w.Body.String()

View File

@@ -0,0 +1,59 @@
//go:build unit
package service
import (
"encoding/json"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestCreateGeminiTestPayload_ImageModel checks the payload built for an image
// model ID: the caller's prompt must be the only text part, and
// generationConfig must request TEXT and IMAGE modalities with a 1:1 aspect ratio.
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
t.Parallel()
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
// Decode only the fields under test; any other payload keys are ignored.
var parsed struct {
Contents []struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"contents"`
GenerationConfig struct {
ResponseModalities []string `json:"responseModalities"`
ImageConfig struct {
AspectRatio string `json:"aspectRatio"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
require.NoError(t, json.Unmarshal(payload, &parsed))
// Exactly one content entry with exactly one part carrying the prompt verbatim.
require.Len(t, parsed.Contents, 1)
require.Len(t, parsed.Contents[0].Parts, 1)
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
}
// TestProcessGeminiStream_EmitsImageEvent feeds processGeminiStream a single
// SSE chunk containing a text part and an inline PNG part, then verifies that
// the recorded output carries both a "content" event and an "image" event whose
// image_url is a base64 data URL.
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
	t.Parallel()
	gin.SetMode(gin.TestMode)

	// One candidate chunk followed by the stream terminator.
	const sse = "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n"

	ginCtx, rec := newSoraTestContext()
	service := &AccountTestService{}

	require.NoError(t, service.processGeminiStream(ginCtx, strings.NewReader(sse)))

	// Every fragment must appear somewhere in the emitted event stream.
	output := rec.Body.String()
	for _, fragment := range []string{
		"\"type\":\"content\"",
		"\"text\":\"ok\"",
		"\"type\":\"image\"",
		"\"image_url\":\"data:image/png;base64,QUJD\"",
		"\"mime_type\":\"image/png\"",
	} {
		require.Contains(t, output, fragment)
	}
}

View File

@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
}
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
mergeAccountExtra(account, updates)
if resetAt != nil {
account.RateLimitResetAt = resetAt
}
if usage.UpdatedAt == nil {
usage.UpdatedAt = &now
}
@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
return true
}
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
if account == nil || !account.IsOAuth() {
return nil, nil
return nil, nil, nil
}
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
return nil, nil, fmt.Errorf("no access token available")
}
modelID := openaipkg.DefaultTestModel
payload := createOpenAITestPayload(modelID, true)
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
}
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("create openai probe request: %w", err)
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
}
req.Host = "chatgpt.com"
req.Header.Set("Content-Type", "application/json")
@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
ResponseHeaderTimeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("build openai probe client: %w", err)
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
updates, err := extractOpenAICodexProbeUpdates(resp)
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
if err != nil {
return nil, err
return nil, nil, err
}
if len(updates) > 0 {
go func(accountID int64, updates map[string]any) {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if len(updates) > 0 || resetAt != nil {
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
return updates, resetAt, nil
}
return nil, nil, nil
}
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
if s == nil || s.accountRepo == nil || accountID <= 0 {
return
}
if len(updates) == 0 && resetAt == nil {
return
}
go func() {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}(account.ID, updates)
return updates, nil
}
if resetAt != nil {
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
}
}()
}
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
if resp == nil {
return nil, nil, nil
}
return nil, nil
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
baseTime := time.Now()
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
if len(updates) > 0 {
return updates, resetAt, nil
}
return nil, resetAt, nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
}
return nil, nil, nil
}
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
if resp == nil {
return nil, nil
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
if len(updates) > 0 {
return updates, nil
}
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
}
return nil, nil
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
return updates, err
}
func mergeAccountExtra(account *Account, updates map[string]any) {

View File

@@ -1,11 +1,36 @@
package service
import (
"context"
"net/http"
"testing"
"time"
)
// accountUsageCodexProbeRepo is a test double layered on stubOpenAIAccountRepo.
// It forwards the arguments of UpdateExtra and SetRateLimited onto channels so
// a test can observe writes that happen on another goroutine. A nil channel
// disables reporting for that call.
type accountUsageCodexProbeRepo struct {
	stubOpenAIAccountRepo
	updateExtraCh chan map[string]any
	rateLimitCh   chan time.Time
}

// UpdateExtra sends a shallow copy of updates on updateExtraCh (when set) and
// always reports success. The copy keeps the receiver isolated from later
// mutation of the caller's map.
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
	if r.updateExtraCh == nil {
		return nil
	}
	snapshot := make(map[string]any, len(updates))
	for key, value := range updates {
		snapshot[key] = value
	}
	r.updateExtraCh <- snapshot
	return nil
}

// SetRateLimited sends resetAt on rateLimitCh (when set) and always reports success.
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
	if r.rateLimitCh == nil {
		return nil
	}
	r.rateLimitCh <- resetAt
	return nil
}
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
t.Parallel()
@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
}
// TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt verifies that a 429
// response carrying fully exhausted codex rate-limit headers is not treated as
// an error: it must yield non-empty usage updates and a non-nil reset time.
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
	t.Parallel()

	// Both windows report 100% usage.
	header := http.Header{}
	for _, kv := range [][2]string{
		{"x-codex-primary-used-percent", "100"},
		{"x-codex-primary-reset-after-seconds", "604800"},
		{"x-codex-primary-window-minutes", "10080"},
		{"x-codex-secondary-used-percent", "100"},
		{"x-codex-secondary-reset-after-seconds", "18000"},
		{"x-codex-secondary-window-minutes", "300"},
	} {
		header.Set(kv[0], kv[1])
	}
	resp := &http.Response{StatusCode: http.StatusTooManyRequests, Header: header}

	updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
	switch {
	case err != nil:
		t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
	case len(updates) == 0:
		t.Fatal("expected codex probe updates from 429 headers")
	case resetAt == nil:
		t.Fatal("expected resetAt from exhausted codex headers")
	}
}
// TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit verifies
// that persistOpenAICodexProbeSnapshot forwards both the extra-field updates and
// the rate-limit reset time to the account repository.
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
t.Parallel()
// Buffered channels let the repo stub hand off each call without blocking.
repo := &accountUsageCodexProbeRepo{
updateExtraCh: make(chan map[string]any, 1),
rateLimitCh: make(chan time.Time, 1),
}
svc := &AccountUsageService{accountRepo: repo}
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
}, &resetAt)
// Wait for the UpdateExtra call, failing after a bounded timeout.
select {
case updates := <-repo.updateExtraCh:
if got := updates["codex_7d_used_percent"]; got != 100.0 {
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe extra persistence timed out")
}
// Wait for the SetRateLimited call; the reset time may differ by up to one second.
select {
case got := <-repo.rateLimitCh:
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe rate limit persistence timed out")
}
}

View File

@@ -2164,6 +2164,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
}
// Gemini 原生请求中的 thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回
// "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。
signatureCheckBody := respBody
if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 {
signatureCheckBody = unwrapped
}
if resp.StatusCode == http.StatusBadRequest &&
s.settingService != nil &&
s.settingService.IsSignatureRectifierEnabled(ctx) &&
isSignatureRelatedError(signatureCheckBody) &&
bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) {
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody)))
upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
if wrapErr == nil {
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: upstreamAction,
body: retryWrappedBody,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession,
groupID: 0,
sessionHash: "",
})
if retryErr == nil {
retryResp := retryResult.resp
if retryResp.StatusCode < 400 {
resp = retryResp
} else {
retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
retryOpsBody := retryRespBody
if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 {
retryOpsBody = retryUnwrapped
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: "signature_retry",
Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))),
Detail: s.getUpstreamErrorDetail(retryOpsBody),
})
respBody = retryRespBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
}
contentType = resp.Header.Get("Content-Type")
}
} else {
if switchErr, ok := IsAntigravityAccountSwitchError(retryErr); ok {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: http.StatusServiceUnavailable,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
return nil, &UpstreamFailoverError{
StatusCode: http.StatusServiceUnavailable,
ForceCacheBilling: switchErr.IsStickySession,
}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr)
}
} else {
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap failed: %v", account.ID, wrapErr)
}
}
// fallback 成功:继续按正常响应处理
if resp.StatusCode < 400 {
goto handleSuccess

View File

@@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
return s.resp, s.err
}
// queuedHTTPUpstreamStub is an upstream test double that replays a scripted
// sequence of responses and errors, one pair per call, and records the body of
// every request it receives.
type queuedHTTPUpstreamStub struct {
	responses     []*http.Response                         // response returned by the i-th call, if any
	errors        []error                                  // error returned by the i-th call, if any
	requestBodies [][]byte                                 // captured body per call; nil when the request had none
	callCount     int                                      // number of calls received so far
	onCall        func(*http.Request, *queuedHTTPUpstreamStub) // optional hook run after callCount is incremented
}

// Do records the request body, advances the call counter, runs the onCall hook,
// and returns the scripted response/error for this call position. When nothing
// is scripted for the position it returns an "unexpected upstream call" error.
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
	// Drain the body for inspection, then restore it so the request stays readable.
	var captured []byte
	if req != nil && req.Body != nil {
		captured, _ = io.ReadAll(req.Body)
		req.Body = io.NopCloser(bytes.NewReader(captured))
	}
	s.requestBodies = append(s.requestBodies, captured)

	position := s.callCount
	s.callCount = position + 1

	// The hook runs before the lookup, so it may still adjust the script.
	if hook := s.onCall; hook != nil {
		hook(req, s)
	}

	var (
		resp *http.Response
		err  error
	)
	if position < len(s.responses) {
		resp = s.responses[position]
	}
	if position < len(s.errors) {
		err = s.errors[position]
	}
	if resp != nil || err != nil {
		return resp, err
	}
	return nil, errors.New("unexpected upstream call")
}

// DoWithTLS ignores its final boolean argument and delegates to Do.
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
	return s.Do(req, proxyURL, accountID, concurrency)
}
type antigravitySettingRepoStub struct{}
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
@@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
require.Equal(t, mappedModel, result.Model)
}
// TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature
// scripts a 400 "Corrupted thought signature." response followed by a 200 SSE
// response and verifies that ForwardGemini retries exactly once, with every
// thoughtSignature in the retried request replaced by the dummy value
// "skip_thought_signature_validator".
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
// Request body with two model turns that each carry a thoughtSignature.
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
c.Request = req
// First upstream reply: wrapped 400 signature error. Second: successful SSE stream.
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req-sig-1"},
},
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req-sig-2"},
},
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
},
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
const originalModel = "gemini-3.1-pro-preview"
const mappedModel = "gemini-3.1-pro-high"
account := &Account{
ID: 7,
Name: "acc-gemini-signature",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
originalModel: mappedModel,
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
// The first request keeps the original signatures; the retry must not.
firstReq := string(upstream.requestBodies[0])
secondReq := string(upstream.requestBodies[1])
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
// The initial 400 must have been recorded as the first ops upstream error event.
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.NotEmpty(t, events)
require.Equal(t, "signature_error", events[0].Kind)
}
// TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover
// verifies that when the signature retry cannot proceed because the account has
// become rate limited for the mapped model, ForwardGemini returns an
// UpstreamFailoverError instead of falling back to the original 400.
func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
c.Request = req
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
const originalModel = "gemini-3.1-pro-preview"
const mappedModel = "gemini-3.1-pro-high"
account := &Account{
ID: 8,
Name: "acc-gemini-signature-failover",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
originalModel: mappedModel,
},
},
}
// Only one response is scripted: a second upstream call would surface as an error.
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req-sig-failover-1"},
},
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
},
},
// During the first call (callCount is already 1 when the hook runs), mark the
// mapped model as rate limited for the next 30 seconds on the account.
onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) {
if stub.callCount != 1 {
return
}
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account.Extra = map[string]any{
modelRateLimitsKey: map[string]any{
mappedModel: map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
}
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true)
require.Nil(t, result)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling)
require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request")
// Expect exactly two ops events, in order: the original signature error, then the failover.
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 2)
require.Equal(t, "signature_error", events[0].Kind)
require.Equal(t, "failover", events[1].Kind)
}
// TestStreamUpstreamResponse_UsageAndFirstToken
// verifies that the usage fields can be accumulated/overwritten and that the first-token time is recorded.
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {

View File

@@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"strconv"
"strings"
"sync"
"time"
@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
return d.Usage7d
}
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
// It is intentionally small so repositories can return it from a single SQL statement.
type APIKeyQuotaUsageState struct {
// QuotaUsed is the key's used quota (presumably the post-increment value — confirm against the repository).
QuotaUsed float64
// Quota is the key's quota limit (NOTE(review): meaning of zero is not visible here — confirm).
Quota float64
// Key is the API key string; UpdateQuotaUsed uses it to invalidate the auth cache.
Key string
// Status is the key's status; UpdateQuotaUsed compares it with StatusAPIKeyQuotaExhausted.
Status string
}
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
return nil
}
type quotaStateReader interface {
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
}
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
if err != nil {
return fmt.Errorf("increment quota used: %w", err)
}
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
s.InvalidateAuthCacheByKey(ctx, state.Key)
}
return nil
}
// Use repository to atomically increment quota_used
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
if err != nil {

View File

@@ -0,0 +1,170 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// quotaStateRepoStub extends quotaBaseAPIKeyRepoStub with a scripted
// IncrementQuotaUsedAndGetState and counts how often that method is called.
type quotaStateRepoStub struct {
	quotaBaseAPIKeyRepoStub
	stateCalls int                    // number of IncrementQuotaUsedAndGetState calls
	state      *APIKeyQuotaUsageState // state to return; nil yields (nil, nil)
	stateErr   error                  // when non-nil, returned instead of state
}

// IncrementQuotaUsedAndGetState records the call and returns the scripted
// outcome. stateErr takes precedence over state; a non-nil state is returned as
// a copy, so callers cannot mutate the stub's own value.
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
	s.stateCalls++
	switch {
	case s.stateErr != nil:
		return nil, s.stateErr
	case s.state == nil:
		return nil, nil
	}
	snapshot := *s.state
	return &snapshot, nil
}
// quotaStateCacheStub is a cache test double (presumably satisfying APIKeyCache —
// confirm against the interface). Every method is a no-op returning zero values,
// except DeleteAuthCache, which records the keys it is asked to delete.
type quotaStateCacheStub struct {
// deleteAuthKeys lists the keys passed to DeleteAuthCache, in call order.
deleteAuthKeys []string
}
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
return 0, nil
}
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
return nil
}
// DeleteAuthCache records key so tests can assert which cache entries were invalidated.
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
return nil
}
// quotaBaseAPIKeyRepoStub is a strict API key repository test double: every
// method panics with an "unexpected ... call" message except GetByID, which
// counts its invocations. Tests embed it and override only the methods they
// expect to be exercised.
type quotaBaseAPIKeyRepoStub struct {
// getByIDCalls is the number of GetByID calls received.
getByIDCalls int
}
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
panic("unexpected Create call")
}
// GetByID counts the call and returns (nil, nil) rather than panicking, so a
// test can assert that no re-read of the API key happened.
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
s.getByIDCalls++
return nil, nil
}
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
panic("unexpected Update call")
}
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected UpdateLastUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
// TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath verifies that
// UpdateQuotaUsed obtains the quota state through a single state call, never
// re-reads the key via GetByID, and evicts the auth cache entry of the key
// whose reported status is quota-exhausted.
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
	const exhaustedKey = "sk-test-quota"

	stateRepo := &quotaStateRepoStub{
		state: &APIKeyQuotaUsageState{
			QuotaUsed: 12,
			Quota:     10,
			Key:       exhaustedKey,
			Status:    StatusAPIKeyQuotaExhausted,
		},
	}
	authCache := &quotaStateCacheStub{}
	svc := &APIKeyService{apiKeyRepo: stateRepo, cache: authCache}

	require.NoError(t, svc.UpdateQuotaUsed(context.Background(), 101, 2))

	require.Equal(t, 1, stateRepo.stateCalls)
	require.Equal(t, 0, stateRepo.getByIDCalls, "fast path should not re-read API key by id")
	require.Equal(t, []string{svc.authCacheKey(exhaustedKey)}, authCache.deleteAuthKeys)
}

View File

@@ -6069,6 +6069,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
intervalCh = intervalTicker.C
}
// Downstream keepalive: prevents proxies / Cloudflare Tunnel from dropping the connection while it is idle.
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
@@ -6267,6 +6283,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
break
}
flusher.Flush()
lastDataAt = time.Now()
}
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
@@ -6298,6 +6315,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping event in Anthropic's native format: clients handle it correctly,
// and it keeps the connection active so proxies such as Cloudflare Tunnel do not drop it.
if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing")
continue
}
flusher.Flush()
}
}

View File

@@ -0,0 +1,75 @@
package service
import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
// TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures checks
// that the original "thoughtSignature" values (under contents[].parts[] and
// under cachedContent.parts[]) no longer appear in the cleaned payload, that
// antigravity.DummyThoughtSignature appears in their place, and that the
// unrelated top-level "signature" field keeps its value.
func TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures(t *testing.T) {
input := []byte(`{
"contents": [
{
"role": "user",
"parts": [{"text": "hello"}]
},
{
"role": "model",
"parts": [
{"text": "thinking", "thought": true, "thoughtSignature": "sig_1"},
{"functionCall": {"name": "toolA", "args": {"k": "v"}}, "thoughtSignature": "sig_2"}
]
}
],
"cachedContent": {
"parts": [{"text": "cached", "thoughtSignature": "sig_3"}]
},
"signature": "keep_me"
}`)
cleaned := CleanGeminiNativeThoughtSignatures(input)
// The cleaned payload must still be valid JSON.
var got map[string]any
require.NoError(t, json.Unmarshal(cleaned, &got))
// The substrings below have no space after the colon, so the Contains
// assertions rely on the cleaned output being serialized in compact form.
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_1"`)
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_2"`)
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_3"`)
require.Contains(t, string(cleaned), `"thoughtSignature":"`+antigravity.DummyThoughtSignature+`"`)
require.Contains(t, string(cleaned), `"signature":"keep_me"`)
}
func TestCleanGeminiNativeThoughtSignatures_InvalidJSONReturnsOriginal(t *testing.T) {
input := []byte(`{"contents":[invalid-json]}`)
cleaned := CleanGeminiNativeThoughtSignatures(input)
require.Equal(t, input, cleaned)
}
// TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField verifies that
// replaceThoughtSignaturesRecursive swaps "thoughtSignature" values for
// antigravity.DummyThoughtSignature at the root and inside a nested slice of
// maps, while leaving sibling "signature" fields untouched.
func TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField(t *testing.T) {
	child := map[string]any{
		"thoughtSignature": "sig_nested",
		"signature":        "keep_nested_signature",
	}
	root := map[string]any{
		"thoughtSignature": "sig_root",
		"signature":        "keep_signature",
		"nested":           []any{child},
	}

	// Root level.
	result, isMap := replaceThoughtSignaturesRecursive(root).(map[string]any)
	require.True(t, isMap)
	require.Equal(t, antigravity.DummyThoughtSignature, result["thoughtSignature"])
	require.Equal(t, "keep_signature", result["signature"])

	// Nested level.
	items, isSlice := result["nested"].([]any)
	require.True(t, isSlice)
	first, isNestedMap := items[0].(map[string]any)
	require.True(t, isNestedMap)
	require.Equal(t, antigravity.DummyThoughtSignature, first["thoughtSignature"])
	require.Equal(t, "keep_nested_signature", first["signature"])
}

View File

@@ -1,6 +1,7 @@
package service
import (
"fmt"
"strings"
)
@@ -226,6 +227,29 @@ func normalizeCodexModel(model string) string {
return "gpt-5.1"
}
// SupportsVerbosity reports whether the verbosity parameter should be kept for
// the given model name.
//
// Names that do not start with "gpt-" always return true. For "gpt-" names the
// version following the prefix is parsed as "<major>" or "<major>.<minor>":
//   - major > 5 returns true;
//   - major < 5 returns false (this includes names where no number could be
//     parsed, which leaves major at 0);
//   - major == 5 with no parsable minor (e.g. "gpt-5", "gpt-5-codex") returns true;
//   - major == 5 with a minor returns true only when minor >= 3.
func SupportsVerbosity(model string) bool {
	if !strings.HasPrefix(model, "gpt-") {
		return true
	}

	var major, minor int
	// The scan error is intentionally ignored: the count of parsed fields is
	// all that is needed to tell "gpt-5" apart from "gpt-5.<minor>".
	parsed, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor)

	switch {
	case major != 5:
		return major > 5
	case parsed == 1:
		// Plain gpt-5 without a minor version.
		return true
	default:
		return minor >= 3
	}
}
func getNormalizedCodexModel(modelID string) string {
if modelID == "" {
return ""

View File

@@ -0,0 +1,512 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Chat Completions format. All account types (OAuth and API
// Key) go through the Responses API conversion path since the upstream only
// exposes the /v1/responses endpoint.
//
// body is the raw client request. promptCacheKey, when non-empty, is used to
// derive the session_id request header; for OAuth accounts it is also injected
// into the request body unless the codex transform yields its own key.
// defaultMappedModel is applied only when the account's model mapping leaves
// the requested model unchanged.
//
// Upstream errors that qualify for failover are returned as
// *UpstreamFailoverError; other upstream errors (status >= 400) are delegated
// to handleChatCompletionsErrorResponse. On success the result carries the
// usage, the original and billing model names and, when set on the converted
// request, the service tier and reasoning effort.
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
// startTime is handed to the response handlers, which derive Duration (and
// first-token latency for streams) from it.
startTime := time.Now()
// 1. Parse Chat Completions request
var chatReq apicompat.ChatCompletionsRequest
if err := json.Unmarshal(body, &chatReq); err != nil {
return nil, fmt.Errorf("parse chat completions request: %w", err)
}
originalModel := chatReq.Model
clientStream := chatReq.Stream
// A nil StreamOptions counts as "do not include usage".
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
// 2. Convert to Responses and forward
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
// 3. Model mapping
mappedModel := account.GetMappedModel(originalModel)
// Fall back to the default only when the account mapping did not change the model.
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel
logger.L().Debug("openai chat_completions: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("stream", clientStream),
)
// 4. Marshal Responses request body, then apply OAuth codex transform
responsesBody, err := json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
}
if account.Type == AccountTypeOAuth {
// The transform operates on a generic map, so the body is round-tripped
// through JSON before and after it.
var reqBody map[string]any
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
// Prefer the cache key reported by the transform; otherwise inject the
// caller-supplied key into the request body.
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
reqBody["prompt_cache_key"] = promptCacheKey
}
responsesBody, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
}
}
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
// Transport-level failure: record a sanitized message for ops and answer
// the client with a generic 502 in Chat Completions error format.
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 8. Handle error response with failover
if resp.StatusCode >= 400 {
// Read at most 2 MiB of the error body, then swap in an in-memory reader
// so the non-failover handler below can read the same bytes again.
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
// The raw body is attached to the ops event only when body logging is
// enabled, truncated to the configured limit (2048 bytes by default).
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
// Same-account retry is only offered for pool-mode accounts, and only for
// retryable statuses or transient processing errors.
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account)
}
// 9. Handle normal response
// The upstream always streams; clientStream only decides whether the SSE
// events are relayed as chunks or buffered into a single JSON response.
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
} else {
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
}
// Propagate ServiceTier and ReasoningEffort to result for billing
if handleErr == nil && result != nil {
if responsesReq.ServiceTier != "" {
st := responsesReq.ServiceTier
result.ServiceTier = &st
}
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
re := responsesReq.Reasoning.Effort
result.ReasoningEffort = &re
}
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if handleErr == nil && account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
return result, handleErr
}
// handleChatCompletionsErrorResponse handles a non-failover upstream error for
// the Chat Completions compat path. It delegates to handleCompatErrorResponse,
// passing writeChatCompletionsError so that any error written to the client
// uses the OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}
// handleChatBufferedStreamingResponse drains the upstream Responses SSE stream,
// remembers the last terminal event (response.completed, response.incomplete
// or response.failed) that carries a response, converts that response to a
// Chat Completions JSON body and writes it to the client with status 200.
//
// Usage is taken from the most recent terminal event that reports usage. If
// the stream ends without any terminal event, a 502 error is written in Chat
// Completions format and an error is returned.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	reader := bufio.NewScanner(resp.Body)
	lineLimit := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		lineLimit = s.cfg.Gateway.MaxLineSize
	}
	reader.Buffer(make([]byte, 0, 64*1024), lineLimit)

	var (
		terminal *apicompat.ResponsesResponse
		usage    OpenAIUsage
	)
	for reader.Scan() {
		raw := reader.Text()
		// Only SSE data lines matter; the [DONE] sentinel carries no event.
		if raw == "data: [DONE]" || !strings.HasPrefix(raw, "data: ") {
			continue
		}

		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(raw[len("data: "):]), &event); err != nil {
			logger.L().Warn("openai chat_completions buffered: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			continue
		}

		if event.Response == nil {
			continue
		}
		switch event.Type {
		case "response.completed", "response.incomplete", "response.failed":
			// Terminal event: handled below.
		default:
			continue
		}

		terminal = event.Response
		reported := event.Response.Usage
		if reported == nil {
			// Keep whatever usage an earlier terminal event reported.
			continue
		}
		usage = OpenAIUsage{
			InputTokens:  reported.InputTokens,
			OutputTokens: reported.OutputTokens,
		}
		if details := reported.InputTokensDetails; details != nil {
			usage.CacheReadInputTokens = details.CachedTokens
		}
	}

	// Cancellation and deadline errors are expected and not worth a warning.
	if err := reader.Err(); err != nil &&
		!errors.Is(err, context.Canceled) &&
		!errors.Is(err, context.DeadlineExceeded) {
		logger.L().Warn("openai chat_completions buffered: read error",
			zap.Error(err),
			zap.String("request_id", requestID),
		)
	}

	if terminal == nil {
		writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
		return nil, fmt.Errorf("upstream stream ended without terminal event")
	}

	chatResp := apicompat.ResponsesToChatCompletions(terminal, originalModel)
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, chatResp)

	result := &OpenAIForwardResult{
		RequestID:    requestID,
		Usage:        usage,
		Model:        originalModel,
		BillingModel: mappedModel,
		Stream:       false,
		Duration:     time.Since(startTime),
	}
	return result, nil
}
// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
//
// It writes the SSE response headers and a 200 status up front, then relays
// events until the upstream stream ends, at which point any final chunks and
// the "data: [DONE]" sentinel are emitted. If a write to the client fails, the
// function returns immediately with the usage collected so far and a nil
// error. When a keepalive interval is configured, the upstream is read in a
// separate goroutine so that SSE comment lines can be sent while the upstream
// is silent.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
includeUsage bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
// SSE headers are set after the filtered upstream headers so they take precedence.
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
// Conversion state shared across all events of this stream.
state := apicompat.NewResponsesEventToChatState()
state.Model = originalModel
state.IncludeUsage = includeUsage
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// resultWithUsage snapshots the usage and timing collected so far.
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
// processDataLine handles one SSE data payload. It returns true only when
// writing to the client failed (client disconnected); parse and marshal
// failures are logged and skipped.
processDataLine := func(payload string) bool {
// First-token latency is measured at the first data line received,
// before the payload is parsed.
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
return false
}
// Extract usage from completion events
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected",
zap.String("request_id", requestID),
)
return true
}
}
// Flush once per event rather than once per chunk.
if len(chunks) > 0 {
c.Writer.Flush()
}
return false
}
// finalizeStream emits any remaining chunks plus the [DONE] sentinel. Write
// errors are deliberately ignored here: the stream is over either way.
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
}
// Send [DONE] sentinel
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
c.Writer.Flush()
return resultWithUsage(), nil
}
// handleScanErr logs upstream read errors, except cancellation and deadline
// errors.
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// Determine keepalive interval
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// No keepalive: fast synchronous path
if keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
}
handleScanErr(scanner.Err())
return finalizeStream()
}
// With keepalive: goroutine + channel + select
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
// done is closed when this function returns; it unblocks a reader goroutine
// that is waiting to send on events.
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
// Reader goroutine: forwards every scanned line, then the scan error if any,
// and closes events to signal end of stream.
go func() {
defer close(events)
for scanner.Scan() {
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
// lastDataAt tracks the last upstream line (of any kind), not the last
// keepalive sent.
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
// Upstream stream ended cleanly.
return finalizeStream()
}
if ev.err != nil {
handleScanErr(ev.err)
return finalizeStream()
}
lastDataAt = time.Now()
line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
case <-keepaliveTicker.C:
// Skip the keepalive if upstream data arrived within the interval.
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send SSE comment as keepalive
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
return resultWithUsage(), nil
}
c.Writer.Flush()
}
}
}
// writeChatCompletionsError sends a JSON error body of the form
// {"error": {"type": errType, "message": message}} with the given HTTP status,
// which is the OpenAI Chat Completions error shape used by this gateway.
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
	detail := gin.H{
		"type":    errType,
		"message": message,
	}
	c.JSON(statusCode, gin.H{"error": detail})
}

View File

@@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
// Non-failover error: return Anthropic-formatted error to client
@@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Record upstream error details for ops logging
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeAnthropicError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from

View File

@@ -52,6 +52,8 @@ const (
openAIWSRetryJitterRatioDefault = 0.2
openAICompactSessionSeedKey = "openai_compact_session_seed"
codexCLIVersion = "0.104.0"
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
openAICodexSnapshotPersistMinInterval = 30 * time.Second
)
// OpenAI allowed headers whitelist (for non-passthrough).
@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct {
nonRetryableFastFallback atomic.Int64
}
// accountWriteThrottle limits how often a write is allowed per account ID:
// after a write is allowed for an ID, further writes for the same ID are
// rejected until minInterval has elapsed. It is safe for concurrent use.
type accountWriteThrottle struct {
	// minInterval is the minimum spacing between two allowed writes for the
	// same ID; a non-positive value disables throttling.
	minInterval time.Duration

	mu sync.Mutex
	// lastByID records when the last write was allowed for each ID.
	lastByID map[int64]time.Time
	// lastSweep is when stale entries were last evicted from lastByID.
	lastSweep time.Time
}

// newAccountWriteThrottle returns a throttle that allows at most one write per
// account ID every minInterval.
func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle {
	return &accountWriteThrottle{
		minInterval: minInterval,
		lastByID:    make(map[int64]time.Time),
	}
}

// Allow reports whether a write for id may proceed at time now, and records
// the write when it may.
//
// It always returns true (recording nothing) for a nil receiver, a
// non-positive id, or a non-positive minInterval. Otherwise it returns false
// when the previous allowed write for id happened less than minInterval
// before now.
func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
	if t == nil || id <= 0 || t.minInterval <= 0 {
		return true
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval {
		return false
	}
	// Tolerate a throttle that was built without the constructor: writing to
	// a nil map would panic.
	if t.lastByID == nil {
		t.lastByID = make(map[int64]time.Time)
	}
	t.lastByID[id] = now

	// Evict entries that can no longer influence a decision once the map grows
	// large. The sweep is limited to once per minInterval so that a map holding
	// many still-fresh entries is not rescanned under the lock on every call.
	// Evicted entries are at least 4*minInterval old and would be allowed
	// anyway, so skipping a sweep never changes the result of Allow.
	const sweepThreshold = 4096
	if len(t.lastByID) > sweepThreshold && now.Sub(t.lastSweep) >= t.minInterval {
		t.lastSweep = now
		cutoff := now.Add(-4 * t.minInterval)
		for accountID, writtenAt := range t.lastByID {
			if writtenAt.Before(cutoff) {
				delete(t.lastByID, accountID)
			}
		}
	}
	return true
}
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
@@ -290,6 +332,7 @@ type OpenAIGatewayService struct {
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics
responseHeaderFilter *responseheaders.CompiledHeaderFilter
codexSnapshotThrottle *accountWriteThrottle
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -332,17 +375,25 @@ func NewOpenAIGatewayService(
nil,
"service.openai_gateway",
),
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
svc.logOpenAIWSModeBootstrap()
return svc
}
// getCodexSnapshotThrottle returns the throttle configured on the service, or
// the package-level default when the receiver is nil or has no throttle set.
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
	if s == nil || s.codexSnapshotThrottle == nil {
		return defaultOpenAICodexSnapshotPersistThrottle
	}
	return s.codexSnapshotThrottle
}
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
@@ -1719,6 +1770,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true
markPatchSet("model", normalizedModel)
}
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
// 确保高版本模型向低版本模型映射不报错
if !SupportsVerbosity(normalizedModel) {
if text, ok := reqBody["text"].(map[string]any); ok {
delete(text, "verbosity")
}
}
}
// Normalize the reasoning.effort parameter (minimal -> none) so it matches the values the upstream accepts.
@@ -2954,6 +3013,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// compatErrorWriter is the signature for format-specific error writers used by
// the compat paths (Chat Completions and Anthropic Messages). An
// implementation writes an error response with the given HTTP status, error
// type and message to the client behind c, in its own wire format.
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
// handleCompatErrorResponse is the shared non-failover error handler for the
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
// tracking, secondary failover) but delegates the final error write to the
// format-specific writer function.
//
// The returned result is always nil. The error is an *UpstreamFailoverError
// when the rate-limit service asks for failover (nothing is written through
// writeError in that case); otherwise it is a plain error describing the
// upstream status, returned after the error response has been written.
func (s *OpenAIGatewayService) handleCompatErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
writeError compatErrorWriter,
) (*OpenAIForwardResult, error) {
// Read at most 2 MiB of the upstream error body; a read error is ignored
// and whatever was read is used.
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
// Fall back to a generic message when the body carries none.
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// The raw body is recorded for ops only when body logging is enabled,
// truncated to the configured limit (2048 bytes by default).
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeError(c, status, errType, errMsg)
// Defensive: upstreamMsg can only be empty here if sanitizing emptied it.
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
// Check custom error codes — if the account does not handle this status,
// return a generic error without exposing upstream details.
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
}
// Track rate limits and decide whether to trigger secondary failover.
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
}
// The ops event is tagged "failover" when the rate-limit service asked for it.
kind := "http_error"
if shouldDisable {
kind = "failover"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
Message: upstreamMsg,
Detail: upstreamDetail,
})
if shouldDisable {
// Nothing is written through writeError here; the failover error is
// returned for the caller to act on.
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: body,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
// Map status code to error type and write response
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
// Same value as the default; listed explicitly.
errType = "api_error"
}
writeError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// openaiStreamingResult streaming response result
type openaiStreamingResult struct {
usage *OpenAIUsage
@@ -4071,11 +4244,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
if len(updates) == 0 && resetAt == nil {
return
}
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
if !shouldPersistUpdates && resetAt == nil {
return
}
go func() {
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if len(updates) > 0 {
if shouldPersistUpdates {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}
if resetAt != nil {

View File

@@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
}
}
// TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites
// submits the same snapshot twice for one account with a one-hour throttle and
// expects exactly one UpdateExtra write to reach the repository.
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) {
	const accountID = 777

	asyncRepo := &openAICodexSnapshotAsyncRepo{
		updateExtraCh: make(chan map[string]any, 2),
		rateLimitCh:   make(chan time.Time, 2),
	}
	gateway := &OpenAIGatewayService{
		accountRepo:           asyncRepo,
		codexSnapshotThrottle: newAccountWriteThrottle(time.Hour),
	}
	usageSnapshot := &OpenAICodexUsageSnapshot{
		PrimaryUsedPercent:         ptrFloat64WS(94),
		PrimaryResetAfterSeconds:   ptrIntWS(3600),
		PrimaryWindowMinutes:       ptrIntWS(10080),
		SecondaryUsedPercent:       ptrFloat64WS(22),
		SecondaryResetAfterSeconds: ptrIntWS(1200),
		SecondaryWindowMinutes:     ptrIntWS(300),
	}

	// Two back-to-back submissions; the second falls inside the throttle window.
	for attempt := 0; attempt < 2; attempt++ {
		gateway.updateCodexUsageSnapshot(context.Background(), accountID, usageSnapshot)
	}

	// The first write must arrive.
	select {
	case <-asyncRepo.updateExtraCh:
	case <-time.After(2 * time.Second):
		t.Fatal("等待第一次 codex 快照落库超时")
	}

	// No second write may follow within the grace period.
	select {
	case extra := <-asyncRepo.updateExtraCh:
		t.Fatalf("unexpected second codex snapshot write: %v", extra)
	case <-time.After(200 * time.Millisecond):
	}
}
// ptrFloat64WS returns a pointer to a freshly allocated copy of v.
func ptrFloat64WS(v float64) *float64 {
	out := new(float64)
	*out = v
	return out
}
// ptrIntWS returns a pointer to a freshly allocated copy of v.
func ptrIntWS(v int) *int {
	out := new(int)
	*out = v
	return out
}

View File

@@ -506,6 +506,48 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.HasError && acc.TempUnschedulableUntil == nil
})), true
case "group_rate_limit_ratio":
if groupID == nil || *groupID <= 0 {
return 0, false
}
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
if availability.Group == nil || availability.Group.TotalAccounts <= 0 {
return 0, true
}
return (float64(availability.Group.RateLimitCount) / float64(availability.Group.TotalAccounts)) * 100, true
case "account_error_ratio":
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
total := int64(len(availability.Accounts))
if total <= 0 {
return 0, true
}
errorCount := countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.HasError && acc.TempUnschedulableUntil == nil
})
return (float64(errorCount) / float64(total)) * 100, true
case "overload_account_count":
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.IsOverloaded
})), true
}
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{

View File

@@ -7,6 +7,7 @@ import (
type OpsRepository interface {
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)

View File

@@ -7,6 +7,8 @@ import (
// opsRepoMock is a test-only OpsRepository implementation with optional function hooks.
type opsRepoMock struct {
InsertErrorLogFn func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
BatchInsertErrorLogsFn func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
@@ -14,9 +16,19 @@ type opsRepoMock struct {
}
// InsertErrorLog forwards to InsertErrorLogFn when that hook is set; without a hook it
// returns ID 0 and a nil error.
func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
	hook := m.InsertErrorLogFn
	if hook == nil {
		return 0, nil
	}
	return hook(ctx, input)
}
// BatchInsertErrorLogs forwards to BatchInsertErrorLogsFn when that hook is set; without a
// hook it returns len(inputs) as the count and a nil error.
func (m *opsRepoMock) BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
	hook := m.BatchInsertErrorLogsFn
	if hook == nil {
		return int64(len(inputs)), nil
	}
	return hook(ctx, inputs)
}
// ListErrorLogs ignores its arguments and always returns an empty first page
// (Page 1, PageSize 20) with a nil error.
func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
	emptyPage := &OpsErrorLogList{
		Errors:   []*OpsErrorLog{},
		Page:     1,
		PageSize: 20,
	}
	return emptyPage, nil
}

View File

@@ -121,14 +121,74 @@ func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool {
}
// RecordError runs entry through prepareErrorLogInput and, when preparation reports the
// entry as storable (ok == true), writes it with opsRepo.InsertErrorLog.
// A preparation or insert failure is logged and returned; a skipped entry returns nil.
// NOTE(review): the `if entry == nil {` line directly below has no body or closing brace in
// this view and looks like residue of the pre-refactor nil check — confirm against the
// actual file before relying on this span.
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
if entry == nil {
prepared, ok, err := s.prepareErrorLogInput(ctx, entry, rawRequestBody)
if err != nil {
log.Printf("[Ops] RecordError prepare failed: %v", err)
return err
}
if !ok {
// Preparation decided there is nothing to store.
return nil
}
if _, err := s.opsRepo.InsertErrorLog(ctx, prepared); err != nil {
// Never bubble up to gateway; best-effort logging.
// NOTE(review): the error is still returned below; presumably callers discard it — confirm.
log.Printf("[Ops] RecordError failed: %v", err)
return err
}
return nil
}
// RecordErrorBatch prepares every entry with prepareErrorLogInput (no raw request body) and
// stores the ones that come back as storable.
//
// Entries whose preparation fails are logged and dropped. A single surviving entry is
// written with InsertErrorLog; two or more go through BatchInsertErrorLogs. If the batch
// insert fails, each entry is retried individually and the first individual failure (if
// any) is returned.
func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertErrorLogInput) error {
	if len(entries) == 0 {
		return nil
	}

	ready := make([]*OpsInsertErrorLogInput, 0, len(entries))
	for i := range entries {
		item, keep, prepErr := s.prepareErrorLogInput(ctx, entries[i], nil)
		switch {
		case prepErr != nil:
			log.Printf("[Ops] RecordErrorBatch prepare failed: %v", prepErr)
		case keep:
			ready = append(ready, item)
		}
	}

	switch len(ready) {
	case 0:
		return nil
	case 1:
		_, singleErr := s.opsRepo.InsertErrorLog(ctx, ready[0])
		if singleErr != nil {
			log.Printf("[Ops] RecordErrorBatch single insert failed: %v", singleErr)
		}
		return singleErr
	}

	_, batchErr := s.opsRepo.BatchInsertErrorLogs(ctx, ready)
	if batchErr == nil {
		return nil
	}
	log.Printf("[Ops] RecordErrorBatch failed, fallback to single inserts: %v", batchErr)

	// Retry one by one; remember only the first failure but keep trying the rest.
	var firstErr error
	for _, item := range ready {
		_, insertErr := s.opsRepo.InsertErrorLog(ctx, item)
		if insertErr == nil {
			continue
		}
		log.Printf("[Ops] RecordErrorBatch fallback insert failed: %v", insertErr)
		if firstErr == nil {
			firstErr = insertErr
		}
	}
	return firstErr
}
func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) (*OpsInsertErrorLogInput, bool, error) {
if entry == nil {
return nil, false, nil
}
if !s.IsMonitoringEnabled(ctx) {
return nil
return nil, false, nil
}
if s.opsRepo == nil {
return nil
return nil, false, nil
}
// Ensure timestamps are always populated.
@@ -185,85 +245,88 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
}
}
// Sanitize + serialize upstream error events list.
if len(entry.UpstreamErrors) > 0 {
const maxEvents = 32
events := entry.UpstreamErrors
if len(events) > maxEvents {
events = events[len(events)-maxEvents:]
if err := sanitizeOpsUpstreamErrors(entry); err != nil {
return nil, false, err
}
return entry, true, nil
}
func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
if entry == nil || len(entry.UpstreamErrors) == 0 {
return nil
}
const maxEvents = 32
events := entry.UpstreamErrors
if len(events) > maxEvents {
events = events[len(events)-maxEvents:]
}
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
for _, ev := range events {
if ev == nil {
continue
}
out := *ev
out.Platform = strings.TrimSpace(out.Platform)
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
if out.AccountID < 0 {
out.AccountID = 0
}
if out.UpstreamStatusCode < 0 {
out.UpstreamStatusCode = 0
}
if out.AtUnixMs < 0 {
out.AtUnixMs = 0
}
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
for _, ev := range events {
if ev == nil {
continue
}
out := *ev
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
msg = truncateString(msg, 2048)
out.Message = msg
out.Platform = strings.TrimSpace(out.Platform)
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
detail := strings.TrimSpace(out.Detail)
if detail != "" {
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
out.Detail = sanitizedDetail
} else {
out.Detail = ""
}
if out.AccountID < 0 {
out.AccountID = 0
}
if out.UpstreamStatusCode < 0 {
out.UpstreamStatusCode = 0
}
if out.AtUnixMs < 0 {
out.AtUnixMs = 0
}
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
msg = truncateString(msg, 2048)
out.Message = msg
detail := strings.TrimSpace(out.Detail)
if detail != "" {
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
out.Detail = sanitizedDetail
} else {
out.Detail = ""
}
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
if out.UpstreamRequestBody != "" {
// Reuse the same sanitization/trimming strategy as request body storage.
// Keep it small so it is safe to persist in ops_error_logs JSON.
sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
if sanitized != "" {
out.UpstreamRequestBody = sanitized
if truncated {
out.Kind = strings.TrimSpace(out.Kind)
if out.Kind == "" {
out.Kind = "upstream"
}
out.Kind = out.Kind + ":request_body_truncated"
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
if out.UpstreamRequestBody != "" {
// Reuse the same sanitization/trimming strategy as request body storage.
// Keep it small so it is safe to persist in ops_error_logs JSON.
sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
if sanitizedBody != "" {
out.UpstreamRequestBody = sanitizedBody
if truncated {
out.Kind = strings.TrimSpace(out.Kind)
if out.Kind == "" {
out.Kind = "upstream"
}
} else {
out.UpstreamRequestBody = ""
out.Kind = out.Kind + ":request_body_truncated"
}
} else {
out.UpstreamRequestBody = ""
}
// Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue
}
evCopy := out
sanitized = append(sanitized, &evCopy)
}
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
entry.UpstreamErrors = nil
// Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue
}
evCopy := out
sanitized = append(sanitized, &evCopy)
}
if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil {
// Never bubble up to gateway; best-effort logging.
log.Printf("[Ops] RecordError failed: %v", err)
return err
}
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
entry.UpstreamErrors = nil
return nil
}

View File

@@ -0,0 +1,103 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestOpsServiceRecordErrorBatch_SanitizesAndBatches sends two entries through
// RecordErrorBatch and asserts that both reach the batch-insert hook, that the first one
// arrives with default phase/type and without any of the secret values, and that the second
// keeps its explicitly set fields.
func TestOpsServiceRecordErrorBatch_SanitizesAndBatches(t *testing.T) {
	t.Parallel()

	var got []*OpsInsertErrorLogInput
	mock := &opsRepoMock{
		BatchInsertErrorLogsFn: func(_ context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
			got = append(got, inputs...)
			return int64(len(inputs)), nil
		},
	}
	ops := NewOpsService(mock, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)

	// Entry carrying secrets in every sanitized field plus a negative status/account ID.
	dirty := &OpsInsertErrorLogInput{
		ErrorBody:            `{"error":"bad","access_token":"secret"}`,
		UpstreamStatusCode:   intPtr(-10),
		UpstreamErrorMessage: strPtr(" upstream failed: https://example.com?access_token=secret-value "),
		UpstreamErrorDetail:  strPtr(`{"authorization":"Bearer secret-token"}`),
		UpstreamErrors: []*OpsUpstreamErrorEvent{
			{
				AccountID:           -2,
				UpstreamStatusCode:  429,
				Message:             " token leaked ",
				Detail:              `{"refresh_token":"secret"}`,
				UpstreamRequestBody: `{"api_key":"secret","messages":[{"role":"user","content":"hello"}]}`,
			},
		},
	}
	// Entry with phase, type and timestamp already populated.
	clean := &OpsInsertErrorLogInput{
		ErrorPhase: "upstream",
		ErrorType:  "upstream_error",
		CreatedAt:  time.Now().UTC(),
	}

	err := ops.RecordErrorBatch(context.Background(), []*OpsInsertErrorLogInput{dirty, clean})
	require.NoError(t, err)
	require.Len(t, got, 2)

	head := got[0]
	require.Equal(t, "internal", head.ErrorPhase)
	require.Equal(t, "api_error", head.ErrorType)
	require.Nil(t, head.UpstreamStatusCode)
	require.NotNil(t, head.UpstreamErrorMessage)
	require.NotContains(t, *head.UpstreamErrorMessage, "secret-value")
	require.Contains(t, *head.UpstreamErrorMessage, "access_token=***")
	require.NotNil(t, head.UpstreamErrorDetail)
	require.NotContains(t, *head.UpstreamErrorDetail, "secret-token")
	require.NotContains(t, head.ErrorBody, "secret")
	require.Nil(t, head.UpstreamErrors)
	require.NotNil(t, head.UpstreamErrorsJSON)
	require.NotContains(t, *head.UpstreamErrorsJSON, "secret")
	require.Contains(t, *head.UpstreamErrorsJSON, "[REDACTED]")

	tail := got[1]
	require.Equal(t, "upstream", tail.ErrorPhase)
	require.Equal(t, "upstream_error", tail.ErrorType)
	require.False(t, tail.CreatedAt.IsZero())
}
// TestOpsServiceRecordErrorBatch_FallsBackToSingleInsert makes the batch hook fail and
// asserts that RecordErrorBatch then inserts each of the two entries individually and
// reports no error.
func TestOpsServiceRecordErrorBatch_FallsBackToSingleInsert(t *testing.T) {
	t.Parallel()

	batchCalls, singleCalls := 0, 0
	mock := &opsRepoMock{
		BatchInsertErrorLogsFn: func(_ context.Context, _ []*OpsInsertErrorLogInput) (int64, error) {
			batchCalls++
			return 0, errors.New("batch failed")
		},
		InsertErrorLogFn: func(_ context.Context, _ *OpsInsertErrorLogInput) (int64, error) {
			singleCalls++
			return int64(singleCalls), nil
		},
	}
	ops := NewOpsService(mock, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)

	batch := []*OpsInsertErrorLogInput{
		{ErrorMessage: "first"},
		{ErrorMessage: "second"},
	}
	require.NoError(t, ops.RecordErrorBatch(context.Background(), batch))
	require.Equal(t, 1, batchCalls)
	require.Equal(t, 2, singleCalls)
}
// strPtr returns a pointer to a freshly allocated copy of v.
func strPtr(v string) *string {
	out := new(string)
	*out = v
	return out
}

View File

@@ -0,0 +1,166 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// resetQuotaUserSubRepoStub is a test repository that implements GetByID, ResetDailyUsage
// and ResetWeeklyUsage itself. All other methods come from the embedded userSubRepoNoop
// (which, per the original note, panics — confirm against its definition).
type resetQuotaUserSubRepoStub struct {
userSubRepoNoop
// sub is the single subscription this stub knows about; nil means "none exists".
sub *UserSubscription
// resetDailyCalled / resetWeeklyCalled record whether the respective reset method ran.
resetDailyCalled bool
resetWeeklyCalled bool
// resetDailyErr / resetWeeklyErr are returned by the respective reset method.
resetDailyErr error
resetWeeklyErr error
}
// GetByID returns a copy of the stored subscription when its ID matches id, and
// ErrSubscriptionNotFound otherwise (including when no subscription is stored).
func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
	if r.sub != nil && r.sub.ID == id {
		// Hand out a copy so callers cannot mutate the stub's state through the result.
		snapshot := *r.sub
		return &snapshot, nil
	}
	return nil, ErrSubscriptionNotFound
}
// ResetDailyUsage records the call and returns resetDailyErr. When no error is configured
// and a subscription is stored, it zeroes DailyUsageUSD and sets DailyWindowStart to
// windowStart.
func (r *resetQuotaUserSubRepoStub) ResetDailyUsage(_ context.Context, _ int64, windowStart time.Time) error {
	r.resetDailyCalled = true
	if r.resetDailyErr != nil {
		return r.resetDailyErr
	}
	if r.sub != nil {
		r.sub.DailyUsageUSD = 0
		r.sub.DailyWindowStart = &windowStart
	}
	return nil
}
// ResetWeeklyUsage records the call and returns the configured resetWeeklyErr; it does not
// touch the stored subscription.
func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ int64, _ time.Time) error {
	configured := r.resetWeeklyErr
	r.resetWeeklyCalled = true
	return configured
}
// newResetQuotaSvc builds a SubscriptionService backed by stub and a no-op group repo,
// with every remaining dependency left nil.
func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService {
	groups := groupRepoNoop{}
	svc := NewSubscriptionService(groups, stub, nil, nil, nil)
	return svc
}
// TestAdminResetQuota_ResetBoth expects both reset methods to run when both flags are set.
func TestAdminResetQuota_ResetBoth(t *testing.T) {
	repo := &resetQuotaUserSubRepoStub{sub: &UserSubscription{ID: 1, UserID: 10, GroupID: 20}}

	refreshed, err := newResetQuotaSvc(repo).AdminResetQuota(context.Background(), 1, true, true)

	require.NoError(t, err)
	require.NotNil(t, refreshed)
	require.True(t, repo.resetDailyCalled, "应调用 ResetDailyUsage")
	require.True(t, repo.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
}
// TestAdminResetQuota_ResetDailyOnly expects only the daily reset to run when just
// resetDaily is set.
func TestAdminResetQuota_ResetDailyOnly(t *testing.T) {
	repo := &resetQuotaUserSubRepoStub{sub: &UserSubscription{ID: 2, UserID: 10, GroupID: 20}}

	refreshed, err := newResetQuotaSvc(repo).AdminResetQuota(context.Background(), 2, true, false)

	require.NoError(t, err)
	require.NotNil(t, refreshed)
	require.True(t, repo.resetDailyCalled, "应调用 ResetDailyUsage")
	require.False(t, repo.resetWeeklyCalled, "不应调用 ResetWeeklyUsage")
}
// TestAdminResetQuota_ResetWeeklyOnly expects only the weekly reset to run when just
// resetWeekly is set.
func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) {
	repo := &resetQuotaUserSubRepoStub{sub: &UserSubscription{ID: 3, UserID: 10, GroupID: 20}}

	refreshed, err := newResetQuotaSvc(repo).AdminResetQuota(context.Background(), 3, false, true)

	require.NoError(t, err)
	require.NotNil(t, refreshed)
	require.False(t, repo.resetDailyCalled, "不应调用 ResetDailyUsage")
	require.True(t, repo.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
}
// TestAdminResetQuota_BothFalseReturnsError expects ErrInvalidInput and no reset calls when
// neither flag is set.
func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) {
	repo := &resetQuotaUserSubRepoStub{sub: &UserSubscription{ID: 7, UserID: 10, GroupID: 20}}

	_, err := newResetQuotaSvc(repo).AdminResetQuota(context.Background(), 7, false, false)

	require.ErrorIs(t, err, ErrInvalidInput)
	require.False(t, repo.resetDailyCalled)
	require.False(t, repo.resetWeeklyCalled)
}
// TestAdminResetQuota_SubscriptionNotFound expects ErrSubscriptionNotFound and no reset
// calls when the repository holds no subscription.
func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) {
	repo := &resetQuotaUserSubRepoStub{}

	_, err := newResetQuotaSvc(repo).AdminResetQuota(context.Background(), 999, true, true)

	require.ErrorIs(t, err, ErrSubscriptionNotFound)
	require.False(t, repo.resetDailyCalled)
	require.False(t, repo.resetWeeklyCalled)
}
// TestAdminResetQuota_ResetDailyUsageError expects a daily-reset failure to be returned and
// the weekly reset to be skipped.
func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) {
	failure := errors.New("db error")
	repo := &resetQuotaUserSubRepoStub{
		sub:           &UserSubscription{ID: 4, UserID: 10, GroupID: 20},
		resetDailyErr: failure,
	}

	_, err := newResetQuotaSvc(repo).AdminResetQuota(context.Background(), 4, true, true)

	require.ErrorIs(t, err, failure)
	require.True(t, repo.resetDailyCalled)
	require.False(t, repo.resetWeeklyCalled, "daily 失败后不应继续调用 weekly")
}
// TestAdminResetQuota_ResetWeeklyUsageError expects a weekly-reset failure to be returned
// to the caller.
func TestAdminResetQuota_ResetWeeklyUsageError(t *testing.T) {
	failure := errors.New("db error")
	repo := &resetQuotaUserSubRepoStub{
		sub:            &UserSubscription{ID: 5, UserID: 10, GroupID: 20},
		resetWeeklyErr: failure,
	}

	_, err := newResetQuotaSvc(repo).AdminResetQuota(context.Background(), 5, false, true)

	require.ErrorIs(t, err, failure)
	require.True(t, repo.resetWeeklyCalled)
}
func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) {
stub := &resetQuotaUserSubRepoStub{
sub: &UserSubscription{
ID: 6,
UserID: 10,
GroupID: 20,
DailyUsageUSD: 99.9,
},
}
svc := newResetQuotaSvc(stub)
result, err := svc.AdminResetQuota(context.Background(), 6, true, false)
require.NoError(t, err)
// ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零,
// 服务应返回第二次 GetByID 的刷新值而非初始的 99.9
require.Equal(t, float64(0), result.DailyUsageUSD, "返回的订阅应反映已归零的用量")
require.True(t, stub.resetDailyCalled)
}

View File

@@ -31,6 +31,7 @@ var (
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
@@ -695,6 +696,36 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
}
// AdminResetQuota manually resets the daily and/or weekly usage windows of a subscription.
//
// At least one of resetDaily / resetWeekly must be true; otherwise ErrInvalidInput is
// returned before any repository call. The new window start is startOfDay(time.Now()),
// matching automatic resets. The daily reset runs first, and if it fails the weekly reset
// is not attempted.
//
// Cached subscription data is invalidated whenever at least one reset has been persisted,
// including the case where the daily reset succeeded and the weekly reset then failed, so
// a partially applied reset is not hidden behind stale cache entries.
//
// On success the subscription is re-read from the repository and returned; on any failure
// the result is nil and the underlying error is returned unchanged.
func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) {
	if !resetDaily && !resetWeekly {
		return nil, ErrInvalidInput
	}

	sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
	if err != nil {
		return nil, err
	}

	// invalidate drops cached views of this subscription, same as CheckAndResetWindows.
	invalidate := func() {
		s.InvalidateSubCache(sub.UserID, sub.GroupID)
		if s.billingCacheService != nil {
			// Best-effort: a failed cache invalidation must not fail the reset itself.
			_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
		}
	}

	windowStart := startOfDay(time.Now())
	if resetDaily {
		if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
			// Nothing has been persisted yet, so caches are still accurate.
			return nil, err
		}
	}
	if resetWeekly {
		if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
			if resetDaily {
				// The daily reset is already persisted; without this the caches would
				// keep serving the pre-reset daily usage.
				invalidate()
			}
			return nil, err
		}
	}

	invalidate()

	// Return the refreshed subscription from DB
	return s.userSubRepo.GetByID(ctx, subscriptionID)
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
// 使用当天零点作为新窗口起始时间