diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 91437ba8..460bd05d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -279,6 +279,9 @@ type GatewayConfig struct { // Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用 AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"` + // 默认重试用完后,额外使用 Antigravity 账号重试的最大次数(0 表示禁用) + AntigravityExtraRetries int `mapstructure:"antigravity_extra_retries"` + // Scheduling: 账号调度相关配置 Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` @@ -883,6 +886,7 @@ func setDefaults() { viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) + viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 6900fa55..361cd8b5 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -39,6 +39,7 @@ type GatewayHandler struct { concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int + antigravityExtraRetries int } // NewGatewayHandler creates a new GatewayHandler @@ -57,6 +58,7 @@ func NewGatewayHandler( pingInterval := time.Duration(0) maxAccountSwitches := 10 maxAccountSwitchesGemini := 3 + antigravityExtraRetries := 10 if cfg != nil { pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second if cfg.Gateway.MaxAccountSwitches > 0 { @@ -65,6 +67,7 @@ func NewGatewayHandler( if cfg.Gateway.MaxAccountSwitchesGemini > 0 { maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini } + antigravityExtraRetries = 
cfg.Gateway.AntigravityExtraRetries } return &GatewayHandler{ gatewayService: gatewayService, @@ -78,6 +81,7 @@ func NewGatewayHandler( concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, + antigravityExtraRetries: antigravityExtraRetries, } } @@ -234,6 +238,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if platform == service.PlatformGemini { maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 + antigravityExtraCount := 0 failedAccountIDs := make(map[int64]struct{}) var lastFailoverErr *service.UpstreamFailoverError var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 @@ -255,6 +260,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, account.ID) + // 额外重试阶段:跳过非 Antigravity 账号 + if switchCount >= maxAccountSwitches && account.Platform != service.PlatformAntigravity { + failedAccountIDs[account.ID] = struct{}{} + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + continue + } + // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) @@ -345,8 +359,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) - return + // 默认重试用完,进入 Antigravity 额外重试 + antigravityExtraCount++ + if antigravityExtraCount > h.antigravityExtraRetries { + h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) + return + } + log.Printf("Account %d: antigravity extra retry %d/%d", account.ID, antigravityExtraCount, h.antigravityExtraRetries) + if !sleepFixedDelay(c.Request.Context(), antigravityExtraRetryDelay) { + return + } + continue } switchCount++ 
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) @@ -399,6 +422,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { maxAccountSwitches := h.maxAccountSwitches switchCount := 0 + antigravityExtraCount := 0 failedAccountIDs := make(map[int64]struct{}) var lastFailoverErr *service.UpstreamFailoverError retryWithFallback := false @@ -422,6 +446,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, account.ID) + // 额外重试阶段:跳过非 Antigravity 账号 + if switchCount >= maxAccountSwitches && account.Platform != service.PlatformAntigravity { + failedAccountIDs[account.ID] = struct{}{} + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + continue + } + // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) @@ -545,8 +578,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) - return + // 默认重试用完,进入 Antigravity 额外重试 + antigravityExtraCount++ + if antigravityExtraCount > h.antigravityExtraRetries { + h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) + return + } + log.Printf("Account %d: antigravity extra retry %d/%d", account.ID, antigravityExtraCount, h.antigravityExtraRetries) + if !sleepFixedDelay(c.Request.Context(), antigravityExtraRetryDelay) { + return + } + continue } switchCount++ log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) @@ -838,6 +880,21 @@ func sleepFailoverDelay(ctx context.Context, switchCount int) bool { } } +const antigravityExtraRetryDelay = 500 * time.Millisecond + +// sleepFixedDelay 
固定延时等待,返回 false 表示 context 已取消。 +func sleepFixedDelay(ctx context.Context, delay time.Duration) bool { + if delay <= 0 { + return true + } + select { + case <-ctx.Done(): + return false + case <-time.After(delay): + return true + } +} + func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody diff --git a/backend/internal/handler/gateway_handler_extra_retry_test.go b/backend/internal/handler/gateway_handler_extra_retry_test.go new file mode 100644 index 00000000..a0777941 --- /dev/null +++ b/backend/internal/handler/gateway_handler_extra_retry_test.go @@ -0,0 +1,417 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --- sleepFixedDelay --- + +func TestSleepFixedDelay_ZeroDelay(t *testing.T) { + got := sleepFixedDelay(context.Background(), 0) + require.True(t, got, "zero delay should return true immediately") +} + +func TestSleepFixedDelay_NegativeDelay(t *testing.T) { + got := sleepFixedDelay(context.Background(), -1*time.Second) + require.True(t, got, "negative delay should return true immediately") +} + +func TestSleepFixedDelay_NormalDelay(t *testing.T) { + start := time.Now() + got := sleepFixedDelay(context.Background(), 50*time.Millisecond) + elapsed := time.Since(start) + require.True(t, got, "normal delay should return true") + require.GreaterOrEqual(t, elapsed, 40*time.Millisecond, "should sleep at least ~50ms") +} + +func TestSleepFixedDelay_ContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + got := sleepFixedDelay(ctx, 10*time.Second) + require.False(t, got, 
"cancelled context should return false") +} + +func TestSleepFixedDelay_ContextTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + got := sleepFixedDelay(ctx, 5*time.Second) + require.False(t, got, "context timeout should return false before delay completes") +} + +// --- antigravityExtraRetryDelay constant --- + +func TestAntigravityExtraRetryDelayValue(t *testing.T) { + require.Equal(t, 500*time.Millisecond, antigravityExtraRetryDelay) +} + +// --- NewGatewayHandler antigravityExtraRetries field --- + +func TestNewGatewayHandler_AntigravityExtraRetries_Default(t *testing.T) { + h := NewGatewayHandler(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + require.Equal(t, 10, h.antigravityExtraRetries, "default should be 10 when cfg is nil") +} + +func TestNewGatewayHandler_AntigravityExtraRetries_FromConfig(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + AntigravityExtraRetries: 5, + }, + } + h := NewGatewayHandler(nil, nil, nil, nil, nil, nil, nil, nil, nil, cfg) + require.Equal(t, 5, h.antigravityExtraRetries, "should use config value") +} + +func TestNewGatewayHandler_AntigravityExtraRetries_ZeroDisables(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + AntigravityExtraRetries: 0, + }, + } + h := NewGatewayHandler(nil, nil, nil, nil, nil, nil, nil, nil, nil, cfg) + require.Equal(t, 0, h.antigravityExtraRetries, "zero should disable extra retries") +} + +// --- handleFailoverAllAccountsExhausted (renamed: using handleFailoverExhausted) --- +// We test the error response format helpers that the extra retry path uses. 
+ +func TestHandleFailoverExhausted_JSON(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + h := &GatewayHandler{} + failoverErr := &service.UpstreamFailoverError{StatusCode: 429} + h.handleFailoverExhausted(c, failoverErr, service.PlatformAntigravity, false) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) + + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + errObj, ok := body["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "rate_limit_error", errObj["type"]) +} + +func TestHandleFailoverExhaustedSimple_JSON(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + h := &GatewayHandler{} + h.handleFailoverExhaustedSimple(c, 502, false) + + require.Equal(t, http.StatusBadGateway, rec.Code) + + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + errObj, ok := body["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errObj["type"]) +} + +// --- Extra retry platform filter logic --- + +func TestExtraRetryPlatformFilter(t *testing.T) { + tests := []struct { + name string + switchCount int + maxAccountSwitch int + platform string + expectSkip bool + }{ + { + name: "default_retry_phase_antigravity_not_skipped", + switchCount: 1, + maxAccountSwitch: 3, + platform: service.PlatformAntigravity, + expectSkip: false, + }, + { + name: "default_retry_phase_gemini_not_skipped", + switchCount: 1, + maxAccountSwitch: 3, + platform: service.PlatformGemini, + expectSkip: false, + }, + { + name: "extra_retry_phase_antigravity_not_skipped", + switchCount: 3, + maxAccountSwitch: 3, + platform: service.PlatformAntigravity, + expectSkip: false, + }, + { + name: "extra_retry_phase_gemini_skipped", + switchCount: 3, + maxAccountSwitch: 3, + platform: service.PlatformGemini, + expectSkip: true, + }, + { + 
name: "extra_retry_phase_anthropic_skipped", + switchCount: 3, + maxAccountSwitch: 3, + platform: service.PlatformAnthropic, + expectSkip: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Replicate the filter condition from the handler + shouldSkip := tt.switchCount >= tt.maxAccountSwitch && tt.platform != service.PlatformAntigravity + require.Equal(t, tt.expectSkip, shouldSkip) + }) + } +} + +// --- Extra retry counter logic --- + +func TestExtraRetryCounterExhaustion(t *testing.T) { + tests := []struct { + name string + maxExtraRetries int + currentExtraCount int + expectExhausted bool + }{ + { + name: "first_extra_retry", + maxExtraRetries: 10, + currentExtraCount: 1, + expectExhausted: false, + }, + { + name: "at_limit", + maxExtraRetries: 10, + currentExtraCount: 10, + expectExhausted: false, + }, + { + name: "exceeds_limit", + maxExtraRetries: 10, + currentExtraCount: 11, + expectExhausted: true, + }, + { + name: "zero_disables_extra_retry", + maxExtraRetries: 0, + currentExtraCount: 1, + expectExhausted: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Replicate the exhaustion condition: antigravityExtraCount > h.antigravityExtraRetries + exhausted := tt.currentExtraCount > tt.maxExtraRetries + require.Equal(t, tt.expectExhausted, exhausted) + }) + } +} + +// --- mapUpstreamError (used by handleFailoverExhausted) --- + +func TestMapUpstreamError(t *testing.T) { + h := &GatewayHandler{} + tests := []struct { + name string + statusCode int + expectedStatus int + expectedType string + }{ + {"429", 429, http.StatusTooManyRequests, "rate_limit_error"}, + {"529", 529, http.StatusServiceUnavailable, "overloaded_error"}, + {"500", 500, http.StatusBadGateway, "upstream_error"}, + {"502", 502, http.StatusBadGateway, "upstream_error"}, + {"503", 503, http.StatusBadGateway, "upstream_error"}, + {"504", 504, http.StatusBadGateway, "upstream_error"}, + {"401", 401, http.StatusBadGateway, 
"upstream_error"}, + {"403", 403, http.StatusBadGateway, "upstream_error"}, + {"unknown", 418, http.StatusBadGateway, "upstream_error"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + status, errType, _ := h.mapUpstreamError(tt.statusCode) + require.Equal(t, tt.expectedStatus, status) + require.Equal(t, tt.expectedType, errType) + }) + } +} + +// --- Gemini native path: handleGeminiFailoverExhausted --- + +func TestHandleGeminiFailoverExhausted_NilError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + h := &GatewayHandler{} + h.handleGeminiFailoverExhausted(c, nil) + + require.Equal(t, http.StatusBadGateway, rec.Code) + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + errObj, ok := body["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "Upstream request failed", errObj["message"]) +} + +func TestHandleGeminiFailoverExhausted_429(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + h := &GatewayHandler{} + failoverErr := &service.UpstreamFailoverError{StatusCode: 429} + h.handleGeminiFailoverExhausted(c, failoverErr) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +// --- handleStreamingAwareError streaming mode --- + +func TestHandleStreamingAwareError_StreamStarted(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // Simulate stream already started: set content type and write initial data + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.WriteHeaderNow() + + h := &GatewayHandler{} + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "test error", true) + + body := rec.Body.String() + require.Contains(t, body, "rate_limit_error") + require.Contains(t, body, "test error") + require.Contains(t, body, "data: ") +} 
+ +func TestHandleStreamingAwareError_NotStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + h := &GatewayHandler{} + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "no model", false) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + errObj, ok := body["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "api_error", errObj["type"]) + require.Equal(t, "no model", errObj["message"]) +} + +// --- Integration: extra retry flow simulation --- + +func TestExtraRetryFlowSimulation(t *testing.T) { + // Simulate the full extra retry flow logic + maxAccountSwitches := 3 + maxExtraRetries := 2 + switchCount := 0 + antigravityExtraCount := 0 + + type attempt struct { + platform string + isFailover bool + } + + // Simulate: 3 default retries (all fail), then 2 extra retries (all fail), then exhausted + attempts := []attempt{ + {service.PlatformAntigravity, true}, // switchCount 0 -> 1 + {service.PlatformGemini, true}, // switchCount 1 -> 2 + {service.PlatformAntigravity, true}, // switchCount 2 -> 3 (reaches max) + {service.PlatformAntigravity, true}, // extra retry 1 + {service.PlatformAntigravity, true}, // extra retry 2 + {service.PlatformAntigravity, true}, // extra retry 3 -> exhausted + } + + var exhausted bool + var skipped int + + for _, a := range attempts { + if exhausted { + break + } + + // Extra retry phase: skip non-Antigravity + if switchCount >= maxAccountSwitches && a.platform != service.PlatformAntigravity { + skipped++ + continue + } + + if a.isFailover { + if switchCount >= maxAccountSwitches { + antigravityExtraCount++ + if antigravityExtraCount > maxExtraRetries { + exhausted = true + continue + } + // extra retry delay elided in simulation + // (sleepFixedDelay is covered by its own tests) + continue + } + switchCount++ + } + } + + require.Equal(t, 3, switchCount, "should have 3 default 
retries") + require.Equal(t, 3, antigravityExtraCount, "counter incremented 3 times") + require.True(t, exhausted, "should be exhausted after exceeding max extra retries") + require.Equal(t, 0, skipped, "no non-antigravity accounts in this simulation") +} + +func TestExtraRetryFlowSimulation_SkipsNonAntigravity(t *testing.T) { + maxAccountSwitches := 2 + switchCount := 2 // already past default retries + antigravityExtraCount := 0 + maxExtraRetries := 5 + + type accountSelection struct { + platform string + } + + selections := []accountSelection{ + {service.PlatformGemini}, // should be skipped + {service.PlatformAnthropic}, // should be skipped + {service.PlatformAntigravity}, // should be attempted + } + + var skippedCount int + var attemptedCount int + + for _, sel := range selections { + if switchCount >= maxAccountSwitches && sel.platform != service.PlatformAntigravity { + skippedCount++ + continue + } + // Simulate failover + antigravityExtraCount++ + if antigravityExtraCount > maxExtraRetries { + break + } + attemptedCount++ + } + + require.Equal(t, 2, skippedCount, "gemini and anthropic accounts should be skipped") + require.Equal(t, 1, attemptedCount, "only antigravity account should be attempted") + require.Equal(t, 1, antigravityExtraCount) +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d5149f22..5a576ab0 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -323,6 +323,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 + antigravityExtraCount := 0 failedAccountIDs := make(map[int64]struct{}) var lastFailoverErr *service.UpstreamFailoverError var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 @@ -340,6 +341,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, 
account.ID) + // 额外重试阶段:跳过非 Antigravity 账号 + if switchCount >= maxAccountSwitches && account.Platform != service.PlatformAntigravity { + failedAccountIDs[account.ID] = struct{}{} + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + continue + } + // 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature // 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。 if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID { @@ -424,15 +434,23 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { - lastFailoverErr = failoverErr - h.handleGeminiFailoverExhausted(c, lastFailoverErr) - return + // 默认重试用完,进入 Antigravity 额外重试 + antigravityExtraCount++ + if antigravityExtraCount > h.antigravityExtraRetries { + h.handleGeminiFailoverExhausted(c, failoverErr) + return + } + log.Printf("Gemini account %d: antigravity extra retry %d/%d", account.ID, antigravityExtraCount, h.antigravityExtraRetries) + if !sleepFixedDelay(c.Request.Context(), antigravityExtraRetryDelay) { + return + } + continue } - lastFailoverErr = failoverErr switchCount++ log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) if account.Platform == service.PlatformAntigravity {