mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-25 17:14:45 +08:00
feat: Antigravity extra failover retries after default retries exhausted
When default failover retries are exhausted, continue retrying with Antigravity accounts only (up to 10 times, configurable via GATEWAY_ANTIGRAVITY_EXTRA_RETRIES). Each extra retry uses a fixed 500ms delay. Non-Antigravity accounts are skipped during the extra retry phase. Applied to all three endpoints: Gemini compat, Claude, and Gemini native API paths.
This commit is contained in:
@@ -279,6 +279,9 @@ type GatewayConfig struct {
|
|||||||
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
|
// Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
|
||||||
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
|
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
|
||||||
|
|
||||||
|
// 默认重试用完后,额外使用 Antigravity 账号重试的最大次数(0 表示禁用)
|
||||||
|
AntigravityExtraRetries int `mapstructure:"antigravity_extra_retries"`
|
||||||
|
|
||||||
// Scheduling: 账号调度相关配置
|
// Scheduling: 账号调度相关配置
|
||||||
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
|
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
|
||||||
|
|
||||||
@@ -883,6 +886,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.max_account_switches", 10)
|
viper.SetDefault("gateway.max_account_switches", 10)
|
||||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||||
|
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
||||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ type GatewayHandler struct {
|
|||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
maxAccountSwitches int
|
maxAccountSwitches int
|
||||||
maxAccountSwitchesGemini int
|
maxAccountSwitchesGemini int
|
||||||
|
antigravityExtraRetries int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayHandler creates a new GatewayHandler
|
// NewGatewayHandler creates a new GatewayHandler
|
||||||
@@ -57,6 +58,7 @@ func NewGatewayHandler(
|
|||||||
pingInterval := time.Duration(0)
|
pingInterval := time.Duration(0)
|
||||||
maxAccountSwitches := 10
|
maxAccountSwitches := 10
|
||||||
maxAccountSwitchesGemini := 3
|
maxAccountSwitchesGemini := 3
|
||||||
|
antigravityExtraRetries := 10
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||||
@@ -65,6 +67,7 @@ func NewGatewayHandler(
|
|||||||
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
|
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
|
||||||
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
|
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
|
||||||
}
|
}
|
||||||
|
antigravityExtraRetries = cfg.Gateway.AntigravityExtraRetries
|
||||||
}
|
}
|
||||||
return &GatewayHandler{
|
return &GatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
@@ -78,6 +81,7 @@ func NewGatewayHandler(
|
|||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||||
|
antigravityExtraRetries: antigravityExtraRetries,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,6 +238,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if platform == service.PlatformGemini {
|
if platform == service.PlatformGemini {
|
||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
|
antigravityExtraCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||||
@@ -255,6 +260,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
|
// 额外重试阶段:跳过非 Antigravity 账号
|
||||||
|
if switchCount >= maxAccountSwitches && account.Platform != service.PlatformAntigravity {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||||
if account.IsInterceptWarmupEnabled() {
|
if account.IsInterceptWarmupEnabled() {
|
||||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||||
@@ -345,9 +359,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
forceCacheBilling = true
|
forceCacheBilling = true
|
||||||
}
|
}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
|
// 默认重试用完,进入 Antigravity 额外重试
|
||||||
|
antigravityExtraCount++
|
||||||
|
if antigravityExtraCount > h.antigravityExtraRetries {
|
||||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.Printf("Account %d: antigravity extra retry %d/%d", account.ID, antigravityExtraCount, h.antigravityExtraRetries)
|
||||||
|
if !sleepFixedDelay(c.Request.Context(), antigravityExtraRetryDelay) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
@@ -399,6 +422,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
|
antigravityExtraCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
retryWithFallback := false
|
retryWithFallback := false
|
||||||
@@ -422,6 +446,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
|
// 额外重试阶段:跳过非 Antigravity 账号
|
||||||
|
if switchCount >= maxAccountSwitches && account.Platform != service.PlatformAntigravity {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||||
if account.IsInterceptWarmupEnabled() {
|
if account.IsInterceptWarmupEnabled() {
|
||||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||||
@@ -545,9 +578,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
forceCacheBilling = true
|
forceCacheBilling = true
|
||||||
}
|
}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
|
// 默认重试用完,进入 Antigravity 额外重试
|
||||||
|
antigravityExtraCount++
|
||||||
|
if antigravityExtraCount > h.antigravityExtraRetries {
|
||||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.Printf("Account %d: antigravity extra retry %d/%d", account.ID, antigravityExtraCount, h.antigravityExtraRetries)
|
||||||
|
if !sleepFixedDelay(c.Request.Context(), antigravityExtraRetryDelay) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
@@ -838,6 +880,21 @@ func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// antigravityExtraRetryDelay is the fixed pause passed to sleepFixedDelay
// between consecutive Antigravity extra retries.
const antigravityExtraRetryDelay = 500 * time.Millisecond
|
||||||
|
|
||||||
|
// sleepFixedDelay waits for a fixed delay or until ctx is done, whichever
// happens first.
//
// It returns true when the full delay elapsed. A delay <= 0 returns true
// immediately without consulting ctx. It returns false when ctx was cancelled
// or its deadline expired before the delay completed.
func sleepFixedDelay(ctx context.Context, delay time.Duration) bool {
	if delay <= 0 {
		return true
	}

	// An explicit timer (instead of time.After) is stopped as soon as the
	// function returns, so a cancelled request does not leave a pending timer
	// behind for the remainder of the delay.
	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
|
||||||
|
|
||||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||||
statusCode := failoverErr.StatusCode
|
statusCode := failoverErr.StatusCode
|
||||||
responseBody := failoverErr.ResponseBody
|
responseBody := failoverErr.ResponseBody
|
||||||
|
|||||||
417
backend/internal/handler/gateway_handler_extra_retry_test.go
Normal file
417
backend/internal/handler/gateway_handler_extra_retry_test.go
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- sleepFixedDelay ---
|
||||||
|
|
||||||
|
func TestSleepFixedDelay_ZeroDelay(t *testing.T) {
|
||||||
|
got := sleepFixedDelay(context.Background(), 0)
|
||||||
|
require.True(t, got, "zero delay should return true immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSleepFixedDelay_NegativeDelay(t *testing.T) {
|
||||||
|
got := sleepFixedDelay(context.Background(), -1*time.Second)
|
||||||
|
require.True(t, got, "negative delay should return true immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSleepFixedDelay_NormalDelay(t *testing.T) {
|
||||||
|
start := time.Now()
|
||||||
|
got := sleepFixedDelay(context.Background(), 50*time.Millisecond)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.True(t, got, "normal delay should return true")
|
||||||
|
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond, "should sleep at least ~50ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSleepFixedDelay_ContextCancelled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // cancel immediately
|
||||||
|
got := sleepFixedDelay(ctx, 10*time.Second)
|
||||||
|
require.False(t, got, "cancelled context should return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSleepFixedDelay_ContextTimeout(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
got := sleepFixedDelay(ctx, 5*time.Second)
|
||||||
|
require.False(t, got, "context timeout should return false before delay completes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- antigravityExtraRetryDelay constant ---
|
||||||
|
|
||||||
|
func TestAntigravityExtraRetryDelayValue(t *testing.T) {
|
||||||
|
require.Equal(t, 500*time.Millisecond, antigravityExtraRetryDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- NewGatewayHandler antigravityExtraRetries field ---
|
||||||
|
|
||||||
|
func TestNewGatewayHandler_AntigravityExtraRetries_Default(t *testing.T) {
|
||||||
|
h := NewGatewayHandler(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
require.Equal(t, 10, h.antigravityExtraRetries, "default should be 10 when cfg is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGatewayHandler_AntigravityExtraRetries_FromConfig(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
AntigravityExtraRetries: 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := NewGatewayHandler(nil, nil, nil, nil, nil, nil, nil, nil, nil, cfg)
|
||||||
|
require.Equal(t, 5, h.antigravityExtraRetries, "should use config value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGatewayHandler_AntigravityExtraRetries_ZeroDisables(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
AntigravityExtraRetries: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := NewGatewayHandler(nil, nil, nil, nil, nil, nil, nil, nil, nil, cfg)
|
||||||
|
require.Equal(t, 0, h.antigravityExtraRetries, "zero should disable extra retries")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- handleFailoverExhausted / handleFailoverExhaustedSimple ---
|
||||||
|
// We test the error response format helpers that the extra retry path uses.
|
||||||
|
|
||||||
|
func TestHandleFailoverExhausted_JSON(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
failoverErr := &service.UpstreamFailoverError{StatusCode: 429}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformAntigravity, false)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
err := json.Unmarshal(rec.Body.Bytes(), &body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
errObj, ok := body["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "rate_limit_error", errObj["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleFailoverExhaustedSimple_JSON(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhaustedSimple(c, 502, false)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
err := json.Unmarshal(rec.Body.Bytes(), &body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
errObj, ok := body["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "upstream_error", errObj["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Extra retry platform filter logic ---
|
||||||
|
|
||||||
|
func TestExtraRetryPlatformFilter(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
switchCount int
|
||||||
|
maxAccountSwitch int
|
||||||
|
platform string
|
||||||
|
expectSkip bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default_retry_phase_antigravity_not_skipped",
|
||||||
|
switchCount: 1,
|
||||||
|
maxAccountSwitch: 3,
|
||||||
|
platform: service.PlatformAntigravity,
|
||||||
|
expectSkip: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default_retry_phase_gemini_not_skipped",
|
||||||
|
switchCount: 1,
|
||||||
|
maxAccountSwitch: 3,
|
||||||
|
platform: service.PlatformGemini,
|
||||||
|
expectSkip: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra_retry_phase_antigravity_not_skipped",
|
||||||
|
switchCount: 3,
|
||||||
|
maxAccountSwitch: 3,
|
||||||
|
platform: service.PlatformAntigravity,
|
||||||
|
expectSkip: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra_retry_phase_gemini_skipped",
|
||||||
|
switchCount: 3,
|
||||||
|
maxAccountSwitch: 3,
|
||||||
|
platform: service.PlatformGemini,
|
||||||
|
expectSkip: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra_retry_phase_anthropic_skipped",
|
||||||
|
switchCount: 3,
|
||||||
|
maxAccountSwitch: 3,
|
||||||
|
platform: service.PlatformAnthropic,
|
||||||
|
expectSkip: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Replicate the filter condition from the handler
|
||||||
|
shouldSkip := tt.switchCount >= tt.maxAccountSwitch && tt.platform != service.PlatformAntigravity
|
||||||
|
require.Equal(t, tt.expectSkip, shouldSkip)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Extra retry counter logic ---
|
||||||
|
|
||||||
|
func TestExtraRetryCounterExhaustion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
maxExtraRetries int
|
||||||
|
currentExtraCount int
|
||||||
|
expectExhausted bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "first_extra_retry",
|
||||||
|
maxExtraRetries: 10,
|
||||||
|
currentExtraCount: 1,
|
||||||
|
expectExhausted: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "at_limit",
|
||||||
|
maxExtraRetries: 10,
|
||||||
|
currentExtraCount: 10,
|
||||||
|
expectExhausted: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exceeds_limit",
|
||||||
|
maxExtraRetries: 10,
|
||||||
|
currentExtraCount: 11,
|
||||||
|
expectExhausted: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero_disables_extra_retry",
|
||||||
|
maxExtraRetries: 0,
|
||||||
|
currentExtraCount: 1,
|
||||||
|
expectExhausted: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Replicate the exhaustion condition: antigravityExtraCount > h.antigravityExtraRetries
|
||||||
|
exhausted := tt.currentExtraCount > tt.maxExtraRetries
|
||||||
|
require.Equal(t, tt.expectExhausted, exhausted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- mapUpstreamError (used by handleFailoverExhausted) ---
|
||||||
|
|
||||||
|
func TestMapUpstreamError(t *testing.T) {
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
statusCode int
|
||||||
|
expectedStatus int
|
||||||
|
expectedType string
|
||||||
|
}{
|
||||||
|
{"429", 429, http.StatusTooManyRequests, "rate_limit_error"},
|
||||||
|
{"529", 529, http.StatusServiceUnavailable, "overloaded_error"},
|
||||||
|
{"500", 500, http.StatusBadGateway, "upstream_error"},
|
||||||
|
{"502", 502, http.StatusBadGateway, "upstream_error"},
|
||||||
|
{"503", 503, http.StatusBadGateway, "upstream_error"},
|
||||||
|
{"504", 504, http.StatusBadGateway, "upstream_error"},
|
||||||
|
{"401", 401, http.StatusBadGateway, "upstream_error"},
|
||||||
|
{"403", 403, http.StatusBadGateway, "upstream_error"},
|
||||||
|
{"unknown", 418, http.StatusBadGateway, "upstream_error"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
status, errType, _ := h.mapUpstreamError(tt.statusCode)
|
||||||
|
require.Equal(t, tt.expectedStatus, status)
|
||||||
|
require.Equal(t, tt.expectedType, errType)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Gemini native path: handleGeminiFailoverExhausted ---
|
||||||
|
|
||||||
|
func TestHandleGeminiFailoverExhausted_NilError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleGeminiFailoverExhausted(c, nil)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||||
|
var body map[string]any
|
||||||
|
err := json.Unmarshal(rec.Body.Bytes(), &body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
errObj, ok := body["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "Upstream request failed", errObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleGeminiFailoverExhausted_429(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
failoverErr := &service.UpstreamFailoverError{StatusCode: 429}
|
||||||
|
h.handleGeminiFailoverExhausted(c, failoverErr)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- handleStreamingAwareError streaming mode ---
|
||||||
|
|
||||||
|
func TestHandleStreamingAwareError_StreamStarted(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
// Simulate stream already started: set content type and write initial data
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.WriteHeaderNow()
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "test error", true)
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "rate_limit_error")
|
||||||
|
require.Contains(t, body, "test error")
|
||||||
|
require.Contains(t, body, "data: ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamingAwareError_NotStreaming(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "no model", false)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||||
|
var body map[string]any
|
||||||
|
err := json.Unmarshal(rec.Body.Bytes(), &body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
errObj, ok := body["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "api_error", errObj["type"])
|
||||||
|
require.Equal(t, "no model", errObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Integration: extra retry flow simulation ---
|
||||||
|
|
||||||
|
func TestExtraRetryFlowSimulation(t *testing.T) {
|
||||||
|
// Simulate the full extra retry flow logic
|
||||||
|
maxAccountSwitches := 3
|
||||||
|
maxExtraRetries := 2
|
||||||
|
switchCount := 0
|
||||||
|
antigravityExtraCount := 0
|
||||||
|
|
||||||
|
type attempt struct {
|
||||||
|
platform string
|
||||||
|
isFailover bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate: 3 default retries (all fail), then 2 extra retries (all fail), then exhausted
|
||||||
|
attempts := []attempt{
|
||||||
|
{service.PlatformAntigravity, true}, // switchCount 0 -> 1
|
||||||
|
{service.PlatformGemini, true}, // switchCount 1 -> 2
|
||||||
|
{service.PlatformAntigravity, true}, // switchCount 2 -> 3 (reaches max)
|
||||||
|
{service.PlatformAntigravity, true}, // extra retry 1
|
||||||
|
{service.PlatformAntigravity, true}, // extra retry 2
|
||||||
|
{service.PlatformAntigravity, true}, // extra retry 3 -> exhausted
|
||||||
|
}
|
||||||
|
|
||||||
|
var exhausted bool
|
||||||
|
var skipped int
|
||||||
|
|
||||||
|
for _, a := range attempts {
|
||||||
|
if exhausted {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extra retry phase: skip non-Antigravity
|
||||||
|
if switchCount >= maxAccountSwitches && a.platform != service.PlatformAntigravity {
|
||||||
|
skipped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.isFailover {
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
antigravityExtraCount++
|
||||||
|
if antigravityExtraCount > maxExtraRetries {
|
||||||
|
exhausted = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// extra retry delay + continue
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switchCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 3, switchCount, "should have 3 default retries")
|
||||||
|
require.Equal(t, 3, antigravityExtraCount, "counter incremented 3 times")
|
||||||
|
require.True(t, exhausted, "should be exhausted after exceeding max extra retries")
|
||||||
|
require.Equal(t, 0, skipped, "no non-antigravity accounts in this simulation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtraRetryFlowSimulation_SkipsNonAntigravity(t *testing.T) {
|
||||||
|
maxAccountSwitches := 2
|
||||||
|
switchCount := 2 // already past default retries
|
||||||
|
antigravityExtraCount := 0
|
||||||
|
maxExtraRetries := 5
|
||||||
|
|
||||||
|
type accountSelection struct {
|
||||||
|
platform string
|
||||||
|
}
|
||||||
|
|
||||||
|
selections := []accountSelection{
|
||||||
|
{service.PlatformGemini}, // should be skipped
|
||||||
|
{service.PlatformAnthropic}, // should be skipped
|
||||||
|
{service.PlatformAntigravity}, // should be attempted
|
||||||
|
}
|
||||||
|
|
||||||
|
var skippedCount int
|
||||||
|
var attemptedCount int
|
||||||
|
|
||||||
|
for _, sel := range selections {
|
||||||
|
if switchCount >= maxAccountSwitches && sel.platform != service.PlatformAntigravity {
|
||||||
|
skippedCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Simulate failover
|
||||||
|
antigravityExtraCount++
|
||||||
|
if antigravityExtraCount > maxExtraRetries {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
attemptedCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 2, skippedCount, "gemini and anthropic accounts should be skipped")
|
||||||
|
require.Equal(t, 1, attemptedCount, "only antigravity account should be attempted")
|
||||||
|
require.Equal(t, 1, antigravityExtraCount)
|
||||||
|
}
|
||||||
@@ -323,6 +323,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
|
antigravityExtraCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||||
@@ -340,6 +341,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
|
// 额外重试阶段:跳过非 Antigravity 账号
|
||||||
|
if switchCount >= maxAccountSwitches && account.Platform != service.PlatformAntigravity {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||||
@@ -424,15 +434,23 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverErr = failoverErr
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||||
forceCacheBilling = true
|
forceCacheBilling = true
|
||||||
}
|
}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverErr = failoverErr
|
// 默认重试用完,进入 Antigravity 额外重试
|
||||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
antigravityExtraCount++
|
||||||
|
if antigravityExtraCount > h.antigravityExtraRetries {
|
||||||
|
h.handleGeminiFailoverExhausted(c, failoverErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverErr = failoverErr
|
log.Printf("Gemini account %d: antigravity extra retry %d/%d", account.ID, antigravityExtraCount, h.antigravityExtraRetries)
|
||||||
|
if !sleepFixedDelay(c.Request.Context(), antigravityExtraRetryDelay) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
|
|||||||
Reference in New Issue
Block a user