mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-09 17:44:46 +08:00
Merge release/custom-0.1.83 into release/custom-0.1.84
This commit is contained in:
@@ -1 +1 @@
|
||||
0.1.83
|
||||
0.1.84.1
|
||||
|
||||
@@ -883,6 +883,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.max_account_switches", 10)
|
||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||
|
||||
@@ -133,6 +133,13 @@ type BulkUpdateAccountsRequest struct {
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
// CheckMixedChannelRequest represents check mixed channel risk request
|
||||
type CheckMixedChannelRequest struct {
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
}
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
type AccountWithConcurrency struct {
|
||||
*dto.Account
|
||||
@@ -283,6 +290,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
|
||||
// POST /api/v1/admin/accounts/check-mixed-channel
|
||||
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
|
||||
var req CheckMixedChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.GroupIDs) == 0 {
|
||||
response.Success(c, gin.H{"has_risk": false})
|
||||
return
|
||||
}
|
||||
|
||||
accountID := int64(0)
|
||||
if req.AccountID != nil {
|
||||
accountID = *req.AccountID
|
||||
}
|
||||
|
||||
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
|
||||
if err != nil {
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
response.Success(c, gin.H{
|
||||
"has_risk": true,
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"has_risk": false})
|
||||
}
|
||||
|
||||
// Create handles creating a new account
|
||||
// POST /api/v1/admin/accounts
|
||||
func (h *AccountHandler) Create(c *gin.Context) {
|
||||
@@ -319,17 +370,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
// 检查是否为混合渠道错误
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
// 返回特殊错误码要求确认
|
||||
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
"require_confirmation": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -383,17 +427,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
// 检查是否为混合渠道错误
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
// 返回特殊错误码要求确认
|
||||
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
"require_confirmation": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
|
||||
router.POST("/api/v1/admin/accounts", accountHandler.Create)
|
||||
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"platform": "antigravity",
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, false, data["has_risk"])
|
||||
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
|
||||
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
|
||||
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
|
||||
}
|
||||
|
||||
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.checkMixedErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"platform": "antigravity",
|
||||
"group_ids": []int64{27},
|
||||
"account_id": 99,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["has_risk"])
|
||||
require.Equal(t, "mixed_channel_warning", data["error"])
|
||||
details, ok := data["details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, float64(27), details["group_id"])
|
||||
require.Equal(t, "claude-max", details["group_name"])
|
||||
require.Equal(t, "Antigravity", details["current_platform"])
|
||||
require.Equal(t, "Anthropic", details["other_platform"])
|
||||
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.createAccountErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"name": "ag-oauth-1",
|
||||
"platform": "antigravity",
|
||||
"type": "oauth",
|
||||
"credentials": map[string]any{"refresh_token": "rt"},
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
require.False(t, hasRequireConfirmation)
|
||||
}
|
||||
|
||||
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.updateAccountErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
require.False(t, hasRequireConfirmation)
|
||||
}
|
||||
@@ -10,19 +10,27 @@ import (
|
||||
)
|
||||
|
||||
type stubAdminService struct {
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
createdAccounts []*service.CreateAccountInput
|
||||
createdProxies []*service.CreateProxyInput
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
mu sync.Mutex
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
createdAccounts []*service.CreateAccountInput
|
||||
createdProxies []*service.CreateProxyInput
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
createAccountErr error
|
||||
updateAccountErr error
|
||||
checkMixedErr error
|
||||
lastMixedCheck struct {
|
||||
accountID int64
|
||||
platform string
|
||||
groupIDs []int64
|
||||
}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStubAdminService() *stubAdminService {
|
||||
@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
|
||||
s.mu.Lock()
|
||||
s.createdAccounts = append(s.createdAccounts, input)
|
||||
s.mu.Unlock()
|
||||
if s.createAccountErr != nil {
|
||||
return nil, s.createAccountErr
|
||||
}
|
||||
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
if s.updateAccountErr != nil {
|
||||
return nil, s.updateAccountErr
|
||||
}
|
||||
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
|
||||
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
s.lastMixedCheck.accountID = currentAccountID
|
||||
s.lastMixedCheck.platform = currentAccountPlatform
|
||||
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
|
||||
return s.checkMixedErr
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
||||
search = strings.TrimSpace(strings.ToLower(search))
|
||||
filtered := make([]service.Proxy, 0, len(s.proxies))
|
||||
|
||||
160
backend/internal/handler/failover_loop.go
Normal file
160
backend/internal/handler/failover_loop.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
|
||||
// GatewayService 隐式实现此接口。
|
||||
type TempUnscheduler interface {
|
||||
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
|
||||
}
|
||||
|
||||
// FailoverAction 表示 failover 错误处理后的下一步动作
|
||||
type FailoverAction int
|
||||
|
||||
const (
|
||||
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue)
|
||||
FailoverContinue FailoverAction = iota
|
||||
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
|
||||
FailoverExhausted
|
||||
// FailoverCanceled context 已取消(调用方应直接 return)
|
||||
FailoverCanceled
|
||||
)
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s),
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
singleAccountBackoffDelay = 2 * time.Second
|
||||
)
|
||||
|
||||
// FailoverState 跨循环迭代共享的 failover 状态
|
||||
type FailoverState struct {
|
||||
SwitchCount int
|
||||
MaxSwitches int
|
||||
FailedAccountIDs map[int64]struct{}
|
||||
SameAccountRetryCount map[int64]int
|
||||
LastFailoverErr *service.UpstreamFailoverError
|
||||
ForceCacheBilling bool
|
||||
hasBoundSession bool
|
||||
}
|
||||
|
||||
// NewFailoverState 创建 failover 状态
|
||||
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
|
||||
return &FailoverState{
|
||||
MaxSwitches: maxSwitches,
|
||||
FailedAccountIDs: make(map[int64]struct{}),
|
||||
SameAccountRetryCount: make(map[int64]int),
|
||||
hasBoundSession: hasBoundSession,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。
|
||||
// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
|
||||
func (s *FailoverState) HandleFailoverError(
|
||||
ctx context.Context,
|
||||
gatewayService TempUnscheduler,
|
||||
accountID int64,
|
||||
platform string,
|
||||
failoverErr *service.UpstreamFailoverError,
|
||||
) FailoverAction {
|
||||
s.LastFailoverErr = failoverErr
|
||||
|
||||
// 缓存计费判断
|
||||
if needForceCacheBilling(s.hasBoundSession, failoverErr) {
|
||||
s.ForceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
|
||||
s.SameAccountRetryCount[accountID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
|
||||
if !sleepWithContext(ctx, sameAccountRetryDelay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
return FailoverContinue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
|
||||
}
|
||||
|
||||
// 加入失败列表
|
||||
s.FailedAccountIDs[accountID] = struct{}{}
|
||||
|
||||
// 检查是否耗尽
|
||||
if s.SwitchCount >= s.MaxSwitches {
|
||||
return FailoverExhausted
|
||||
}
|
||||
|
||||
// 递增切换计数
|
||||
s.SwitchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d",
|
||||
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
|
||||
|
||||
// Antigravity 平台换号线性递增延时
|
||||
if platform == service.PlatformAntigravity {
|
||||
delay := time.Duration(s.SwitchCount-1) * time.Second
|
||||
if !sleepWithContext(ctx, delay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
}
|
||||
|
||||
return FailoverContinue
|
||||
}
|
||||
|
||||
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
|
||||
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
|
||||
// 清除排除列表、等待退避后重新选号。
|
||||
//
|
||||
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
|
||||
// 返回 FailoverExhausted 时,调用方应返回错误响应。
|
||||
// 返回 FailoverCanceled 时,调用方应直接 return。
|
||||
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
|
||||
if s.LastFailoverErr != nil &&
|
||||
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
|
||||
s.SwitchCount <= s.MaxSwitches {
|
||||
|
||||
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
|
||||
singleAccountBackoffDelay, s.SwitchCount)
|
||||
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
|
||||
s.SwitchCount, s.MaxSwitches)
|
||||
s.FailedAccountIDs = make(map[int64]struct{})
|
||||
return FailoverContinue
|
||||
}
|
||||
return FailoverExhausted
|
||||
}
|
||||
|
||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
|
||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
|
||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||
}
|
||||
|
||||
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) bool {
|
||||
if d <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(d):
|
||||
return true
|
||||
}
|
||||
}
|
||||
732
backend/internal/handler/failover_loop_test.go
Normal file
732
backend/internal/handler/failover_loop_test.go
Normal file
@@ -0,0 +1,732 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
|
||||
type mockTempUnscheduler struct {
|
||||
calls []tempUnscheduleCall
|
||||
}
|
||||
|
||||
type tempUnscheduleCall struct {
|
||||
accountID int64
|
||||
failoverErr *service.UpstreamFailoverError
|
||||
}
|
||||
|
||||
func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) {
|
||||
m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError {
|
||||
return &service.UpstreamFailoverError{
|
||||
StatusCode: statusCode,
|
||||
RetryableOnSameAccount: retryable,
|
||||
ForceCacheBilling: forceBilling,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NewFailoverState 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewFailoverState(t *testing.T) {
|
||||
t.Run("初始化字段正确", func(t *testing.T) {
|
||||
fs := NewFailoverState(5, true)
|
||||
require.Equal(t, 5, fs.MaxSwitches)
|
||||
require.Equal(t, 0, fs.SwitchCount)
|
||||
require.NotNil(t, fs.FailedAccountIDs)
|
||||
require.Empty(t, fs.FailedAccountIDs)
|
||||
require.NotNil(t, fs.SameAccountRetryCount)
|
||||
require.Empty(t, fs.SameAccountRetryCount)
|
||||
require.Nil(t, fs.LastFailoverErr)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
require.True(t, fs.hasBoundSession)
|
||||
})
|
||||
|
||||
t.Run("无绑定会话", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
require.Equal(t, 3, fs.MaxSwitches)
|
||||
require.False(t, fs.hasBoundSession)
|
||||
})
|
||||
|
||||
t.Run("零最大切换次数", func(t *testing.T) {
|
||||
fs := NewFailoverState(0, false)
|
||||
require.Equal(t, 0, fs.MaxSwitches)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// sleepWithContext 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSleepWithContext(t *testing.T) {
|
||||
t.Run("零时长立即返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), 0)
|
||||
require.True(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("负时长立即返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), -1*time.Second)
|
||||
require.True(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("正常等待后返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), 50*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
require.True(t, ok)
|
||||
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond)
|
||||
require.Less(t, elapsed, 500*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("已取消context立即返回false", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(ctx, 5*time.Second)
|
||||
require.False(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("等待期间context取消返回false", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(ctx, 5*time.Second)
|
||||
elapsed := time.Since(start)
|
||||
require.False(t, ok)
|
||||
require.Less(t, elapsed, 500*time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 基本切换流程
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_BasicSwitch(t *testing.T) {
|
||||
t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
require.Equal(t, err, fs.LastFailoverErr)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
require.Empty(t, mock.calls, "不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) {
|
||||
// switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0")
|
||||
})
|
||||
|
||||
t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) {
|
||||
// switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1 // 模拟已切换一次
|
||||
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s")
|
||||
require.Less(t, elapsed, 3*time.Second)
|
||||
})
|
||||
|
||||
t.Run("连续切换直到耗尽", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(2, false)
|
||||
|
||||
// 第一次切换:0→1
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
|
||||
// 第二次切换:1→2
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
|
||||
// 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2)
|
||||
err3 := newTestFailoverErr(503, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增")
|
||||
|
||||
// 验证失败账号列表
|
||||
require.Len(t, fs.FailedAccountIDs, 3)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(300))
|
||||
|
||||
// LastFailoverErr 应为最后一次的错误
|
||||
require.Equal(t, err3, fs.LastFailoverErr)
|
||||
})
|
||||
|
||||
t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(0, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Equal(t, 0, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_CacheBilling(t *testing.T) {
|
||||
t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("两者均为false时不设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("一旦设置不会被后续错误重置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
// 第一次:ForceCacheBilling=true → 设置
|
||||
err1 := newTestFailoverErr(500, false, true)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
|
||||
// 第二次:ForceCacheBilling=false → 仍然保持 true
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
t.Run("第一次重试返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数")
|
||||
require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表")
|
||||
require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule")
|
||||
// 验证等待了 sameAccountRetryDelay (500ms)
|
||||
require.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
|
||||
// 验证 TempUnschedule 被调用
|
||||
require.Len(t, mock.calls, 1)
|
||||
require.Equal(t, int64(100), mock.calls[0].accountID)
|
||||
require.Equal(t, err, mock.calls[0].failoverErr)
|
||||
})
|
||||
|
||||
t.Run("不同账号独立跟踪重试次数", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 账号 100 第一次重试
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 账号 200 第一次重试(独立计数)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[200])
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响")
|
||||
})
|
||||
|
||||
t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — TempUnschedule 调用验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Empty(t, mock.calls)
|
||||
})
|
||||
|
||||
t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
require.Equal(t, int64(42), mock.calls[0].accountID)
|
||||
require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode)
|
||||
require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — Context 取消
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_ContextCanceled(t *testing.T) {
|
||||
t.Run("同账号重试sleep期间context取消", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(ctx, mock, 100, "openai", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||
// 重试计数仍应递增
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
})
|
||||
|
||||
t.Run("Antigravity延迟期间context取消", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — FailedAccountIDs 跟踪
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_FailedAccountIDs(t *testing.T) {
|
||||
t.Run("切换时添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||
require.Len(t, fs.FailedAccountIDs, 2)
|
||||
})
|
||||
|
||||
t.Run("耗尽时也添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(0, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
|
||||
t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false))
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.NotContains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
|
||||
t.Run("同一账号多次切换不重复添加", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — LastFailoverErr 更新
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_LastFailoverErr(t *testing.T) {
|
||||
t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.Equal(t, err1, fs.LastFailoverErr)
|
||||
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.Equal(t, err2, fs.LastFailoverErr)
|
||||
})
|
||||
|
||||
t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, err, fs.LastFailoverErr)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 综合集成场景
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
// 3. 账号 200 遇到不可重试错误 → 直接切换
|
||||
switchErr := newTestFailoverErr(500, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
|
||||
// 4. 账号 300 遇到不可重试错误 → 再切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 3, fs.SwitchCount)
|
||||
|
||||
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
|
||||
// 最终状态验证
|
||||
require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增")
|
||||
require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中")
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("模拟Antigravity平台完整流程", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(2, false)
|
||||
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
// 第一次切换:delay = 0s
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0")
|
||||
|
||||
// 第二次切换:delay = 1s
|
||||
start = time.Now()
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||
elapsed = time.Since(start)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s")
|
||||
|
||||
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
|
||||
start = time.Now()
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err)
|
||||
elapsed = time.Since(start)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟")
|
||||
})
|
||||
|
||||
t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false) // hasBoundSession=false
|
||||
|
||||
// 第一次:ForceCacheBilling=false
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
|
||||
// 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换)
|
||||
err2 := newTestFailoverErr(500, false, true)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling")
|
||||
|
||||
// 第三次:ForceCacheBilling=false,但状态仍保持 true
|
||||
err3 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||
require.True(t, fs.ForceCacheBilling, "不应重置")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 边界条件
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_EdgeCases(t *testing.T) {
|
||||
t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(0, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
})
|
||||
|
||||
t.Run("AccountID为0也能正常跟踪", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, true, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[0])
|
||||
})
|
||||
|
||||
t.Run("负AccountID也能正常跟踪", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, true, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[-1])
|
||||
})
|
||||
|
||||
t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleSelectionExhausted 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleSelectionExhausted(t *testing.T) {
|
||||
t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
// LastFailoverErr 为 nil
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
})
|
||||
|
||||
t.Run("非503错误返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
})
|
||||
|
||||
t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.FailedAccountIDs[100] = struct{}{}
|
||||
fs.SwitchCount = 1
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表")
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s")
|
||||
require.Less(t, elapsed, 5*time.Second)
|
||||
})
|
||||
|
||||
t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(2, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.SwitchCount = 3 // > MaxSwitches(2)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "不应等待")
|
||||
})
|
||||
|
||||
t.Run("503但context已取消_返回Canceled", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(ctx)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||
})
|
||||
|
||||
t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) {
|
||||
fs := NewFailoverState(2, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
})
|
||||
}
|
||||
@@ -232,12 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
@@ -247,31 +242,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
@@ -346,8 +338,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
@@ -360,40 +352,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
lastFailoverErr = failoverErr
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁并切换账号
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||
}
|
||||
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||
case FailoverExhausted:
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
@@ -421,7 +389,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
}(result, account, userAgent, clientIP, fs.ForceCacheBilling)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -442,41 +410,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
|
||||
retryWithFallback := false
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
@@ -549,8 +509,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
@@ -598,40 +558,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
lastFailoverErr = failoverErr
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁并切换账号
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||
}
|
||||
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||
case FailoverExhausted:
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
@@ -659,7 +595,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
}(result, account, userAgent, clientIP, fs.ForceCacheBilling)
|
||||
return
|
||||
}
|
||||
if !retryWithFallback {
|
||||
@@ -893,65 +829,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
|
||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
|
||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||
}
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
|
||||
func sleepSameAccountRetryDelay(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
||||
delay := time.Duration(switchCount-1) * time.Second
|
||||
if delay <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
|
||||
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
|
||||
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
|
||||
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
|
||||
// 固定短延时:2s
|
||||
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
const delay = 2 * time.Second
|
||||
|
||||
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// sleepAntigravitySingleAccountBackoff 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.True(t, ok, "should return true when context is not canceled")
|
||||
// 固定延迟 2s
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
|
||||
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
|
||||
}
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.False(t, ok, "should return false when context is canceled")
|
||||
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
|
||||
}
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
|
||||
// 验证不同 retryCount 都使用固定 2s 延迟
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.True(t, ok)
|
||||
// 即使 retryCount=5,延迟仍然是固定的 2s
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
|
||||
require.Less(t, elapsed, 5*time.Second)
|
||||
}
|
||||
@@ -0,0 +1,340 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
|
||||
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
|
||||
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
|
||||
|
||||
type fakeSchedulerCache struct {
|
||||
accounts []*service.Account
|
||||
}
|
||||
|
||||
func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return f.accounts, true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
|
||||
func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil }
|
||||
func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil }
|
||||
func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil }
|
||||
|
||||
type fakeGroupRepo struct {
|
||||
group *service.Group
|
||||
}
|
||||
|
||||
func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil }
|
||||
func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) {
|
||||
return f.group, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
|
||||
return f.group, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil }
|
||||
func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil }
|
||||
func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil }
|
||||
func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil }
|
||||
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil }
|
||||
func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeConcurrencyCache struct{}
|
||||
|
||||
func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil }
|
||||
func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
|
||||
schedulerCache := &fakeSchedulerCache{accounts: accounts}
|
||||
schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil)
|
||||
|
||||
gwSvc := service.NewGatewayService(
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
nil, // cache (disable sticky)
|
||||
nil, // cfg
|
||||
schedulerSnapshot,
|
||||
nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
|
||||
nil, // billingService
|
||||
nil, // rateLimitService
|
||||
nil, // billingCacheService
|
||||
nil, // identityService
|
||||
nil, // httpUpstream
|
||||
nil, // deferredService
|
||||
nil, // claudeTokenProvider
|
||||
nil, // sessionLimitCache
|
||||
nil, // digestStore
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
h := &GatewayHandler{
|
||||
gatewayService: gwSvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
concurrencyHelper: concurrencyHelper,
|
||||
// 这些字段对本测试不敏感,保持较小即可
|
||||
maxAccountSwitches: 1,
|
||||
maxAccountSwitchesGemini: 1,
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
billingCacheSvc.Stop()
|
||||
}
|
||||
return h, cleanup
|
||||
}
|
||||
|
||||
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(2001)
|
||||
accountID := int64(1001)
|
||||
|
||||
group := &service.Group{
|
||||
ID: groupID,
|
||||
Hydrated: true,
|
||||
Platform: service.PlatformAnthropic, // /v1/messages(Claude兼容)入口
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: accountID,
|
||||
Name: "ag-1",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "tok_xxx",
|
||||
"intercept_warmup_requests": true,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中
|
||||
},
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||
}
|
||||
|
||||
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||
defer cleanup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||
}`)
|
||||
req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group))
|
||||
c.Request = req
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 3001,
|
||||
UserID: 4001,
|
||||
GroupID: &groupID,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 4001,
|
||||
Concurrency: 10,
|
||||
Balance: 100,
|
||||
},
|
||||
Group: group,
|
||||
}
|
||||
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||
|
||||
h.Messages(c)
|
||||
|
||||
require.Equal(t, 200, rec.Code)
|
||||
|
||||
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
|
||||
selected, ok := c.Get(opsAccountIDKey)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, accountID, selected)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||
|
||||
content, ok := resp["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
first, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "New Conversation", first["text"])
|
||||
}
|
||||
|
||||
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(2002)
|
||||
accountID := int64(1002)
|
||||
|
||||
group := &service.Group{
|
||||
ID: groupID,
|
||||
Hydrated: true,
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: accountID,
|
||||
Name: "ag-2",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "tok_xxx",
|
||||
"intercept_warmup_requests": true,
|
||||
},
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||
}
|
||||
|
||||
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||
defer cleanup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||
}`)
|
||||
req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
|
||||
// - 写入 request.Context(Service读取)
|
||||
// - 写入 gin.Context(Handler快速读取)
|
||||
ctx := context.WithValue(req.Context(), ctxkey.Group, group)
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity)
|
||||
req = req.WithContext(ctx)
|
||||
c.Request = req
|
||||
c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 3002,
|
||||
UserID: 4002,
|
||||
GroupID: &groupID,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 4002,
|
||||
Concurrency: 10,
|
||||
Balance: 100,
|
||||
},
|
||||
Group: group,
|
||||
}
|
||||
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||
|
||||
h.Messages(c)
|
||||
|
||||
require.Equal(t, 200, rec.Code)
|
||||
|
||||
selected, ok := c.Get(opsAccountIDKey)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, accountID, selected)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||
}
|
||||
@@ -321,11 +321,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
@@ -335,27 +331,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
}
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||
return
|
||||
}
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
@@ -429,8 +422,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 5) forward (根据平台分流)
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||
@@ -443,24 +436,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverErr = failoverErr
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverExhausted:
|
||||
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
@@ -506,7 +491,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
}(result, account, userAgent, clientIP, fs.ForceCacheBilling)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ const (
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// User-Agent(与 Antigravity-Manager 保持一致)
|
||||
UserAgent = "antigravity/1.15.8 windows/amd64"
|
||||
UserAgent = "antigravity/1.16.5 windows/amd64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
|
||||
@@ -208,6 +208,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
|
||||
@@ -50,6 +50,7 @@ type AdminService interface {
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
||||
|
||||
// Proxy management
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
|
||||
@@ -1706,6 +1707,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
|
||||
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
|
||||
if s.proxyLatencyCache == nil || len(proxies) == 0 {
|
||||
return
|
||||
|
||||
@@ -85,7 +85,6 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
||||
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
|
||||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||
)
|
||||
@@ -1311,6 +1310,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||||
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||||
billingModel := mappedModel
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -1372,6 +1372,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
|
||||
}
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
resp := result.resp
|
||||
@@ -1620,7 +1624,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Model: billingModel, // 使用映射模型用于计费和日志
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -1974,6 +1978,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if mappedModel == "" {
|
||||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
||||
}
|
||||
billingModel := mappedModel
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -2044,6 +2049,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response")
|
||||
}
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
resp := result.resp
|
||||
@@ -2199,7 +2208,7 @@ handleSuccess:
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Model: billingModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -2644,7 +2653,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
defaultDur := s.getDefaultRateLimitDuration()
|
||||
|
||||
// 尝试解析模型 key 并设置模型级限流
|
||||
modelKey := resolveAntigravityModelKey(requestedModel)
|
||||
//
|
||||
// 注意:requestedModel 可能是“映射前”的请求模型名(例如 claude-opus-4-6),
|
||||
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
|
||||
// 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。
|
||||
modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||
if strings.TrimSpace(modelKey) == "" {
|
||||
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
|
||||
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
|
||||
modelKey = resolveAntigravityModelKey(requestedModel)
|
||||
}
|
||||
if modelKey != "" {
|
||||
ra := s.resolveResetTime(resetAt, defaultDur)
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
||||
@@ -3875,7 +3893,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
originalModel := claudeReq.Model
|
||||
billingModel := originalModel
|
||||
|
||||
// 构建上游请求 URL
|
||||
upstreamURL := baseURL + "/v1/messages"
|
||||
@@ -3928,7 +3945,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Model: originalModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -3969,7 +3986,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Model: originalModel,
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: duration,
|
||||
FirstTokenMs: firstTokenMs,
|
||||
|
||||
@@ -133,6 +133,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
type antigravitySettingRepoStub struct{}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
@@ -159,8 +189,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -417,6 +448,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
|
||||
// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型
|
||||
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hello"},
|
||||
},
|
||||
"max_tokens": 16,
|
||||
"stream": true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-bill-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
const mappedModel = "gemini-3-pro-high"
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Name: "acc-forward-billing",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-5": mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
|
||||
// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型
|
||||
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-bill-2"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
const mappedModel = "gemini-3-pro-high"
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "acc-gemini-billing",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-2.5-flash": mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
// --- 流式 happy path 测试 ---
|
||||
|
||||
// TestStreamUpstreamResponse_NormalComplete
|
||||
|
||||
@@ -191,6 +191,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity}
|
||||
|
||||
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底
|
||||
// 场景:requestedModel 会被默认映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking)
|
||||
body := buildGeminiRateLimitBody("5s")
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false)
|
||||
|
||||
require.Nil(t, result)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
||||
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
||||
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
||||
|
||||
@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||
boolVal, ok := v.(bool)
|
||||
assert.True(t, ok, "value should be bool")
|
||||
assert.True(t, ok, "value should be a bool")
|
||||
assert.True(t, boolVal)
|
||||
}
|
||||
|
||||
|
||||
@@ -890,6 +890,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Priority: 1,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Priority: 2,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}},
|
||||
},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号")
|
||||
|
||||
acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "supporting model")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(50)
|
||||
@@ -1065,6 +1114,36 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-无映射配置-支持所有模型",
|
||||
account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
|
||||
},
|
||||
},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
|
||||
},
|
||||
},
|
||||
model: "gemini-2.5-pro",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -2549,10 +2549,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
}
|
||||
// Gemini API Key 账户直接透传,由上游判断模型是否支持
|
||||
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
|
||||
return true
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user