Merge pull request #627 from touwaeriol/pr/bugfixes-and-enhancements

feat: 反重力(Antigravity)增强、Failover 重构及新模型支持
This commit is contained in:
Wesley Liddick
2026-02-25 08:30:25 +08:00
committed by GitHub
41 changed files with 2591 additions and 651 deletions

View File

@@ -1158,6 +1158,7 @@ func setDefaults() {
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))

View File

@@ -74,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
// Claude 详细版本 ID 映射
@@ -89,16 +90,18 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-2.5-pro": "gemini-2.5-pro",
// Gemini 3 白名单
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3.1-pro-high",
"gemini-3-pro-low": "gemini-3.1-pro-low",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
// Gemini 3.1 透传
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
// Gemini 3 preview 映射
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3.1-pro-high",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
// Gemini 3.1 白名单
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
// Gemini 3.1 preview 映射
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
// 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",

View File

@@ -139,6 +139,13 @@ type BulkUpdateAccountsRequest struct {
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// CheckMixedChannelRequest represents check mixed channel risk request
type CheckMixedChannelRequest struct {
Platform string `json:"platform" binding:"required"`
GroupIDs []int64 `json:"group_ids"`
AccountID *int64 `json:"account_id"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*dto.Account
@@ -389,6 +396,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
// POST /api/v1/admin/accounts/check-mixed-channel
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
var req CheckMixedChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.GroupIDs) == 0 {
response.Success(c, gin.H{"has_risk": false})
return
}
accountID := int64(0)
if req.AccountID != nil {
accountID = *req.AccountID
}
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
response.Success(c, gin.H{
"has_risk": true,
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"has_risk": false})
}
// Create handles creating a new account
// POST /api/v1/admin/accounts
func (h *AccountHandler) Create(c *gin.Context) {
@@ -431,17 +482,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
@@ -501,17 +545,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}

View File

@@ -0,0 +1,147 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
return router
}
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["has_risk"])
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
}
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.checkMixedErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
"account_id": 99,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["has_risk"])
require.Equal(t, "mixed_channel_warning", data["error"])
details, ok := data["details"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(27), details["group_id"])
require.Equal(t, "claude-max", details["group_name"])
require.Equal(t, "Antigravity", details["current_platform"])
require.Equal(t, "Anthropic", details["other_platform"])
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
}
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.createAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"name": "ag-oauth-1",
"platform": "antigravity",
"type": "oauth",
"credentials": map[string]any{"refresh_token": "rt"},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.updateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}

View File

@@ -10,19 +10,27 @@ import (
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
mu sync.Mutex
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
createAccountErr error
updateAccountErr error
checkMixedErr error
lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
}
mu sync.Mutex
}
func newStubAdminService() *stubAdminService {
@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock()
if s.createAccountErr != nil {
return nil, s.createAccountErr
}
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
if s.updateAccountErr != nil {
return nil, s.updateAccountErr
}
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
s.lastMixedCheck.accountID = currentAccountID
s.lastMixedCheck.platform = currentAccountPlatform
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
return s.checkMixedErr
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies))

View File

@@ -0,0 +1,160 @@
package handler
import (
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
// GatewayService 隐式实现此接口。
type TempUnscheduler interface {
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
}
// FailoverAction 表示 failover 错误处理后的下一步动作
type FailoverAction int
const (
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue
FailoverContinue FailoverAction = iota
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
FailoverExhausted
// FailoverCanceled context 已取消(调用方应直接 return
FailoverCanceled
)
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s
// Handler 层只需短暂间隔后重新进入 Service 层即可。
singleAccountBackoffDelay = 2 * time.Second
)
// FailoverState 跨循环迭代共享的 failover 状态
type FailoverState struct {
SwitchCount int
MaxSwitches int
FailedAccountIDs map[int64]struct{}
SameAccountRetryCount map[int64]int
LastFailoverErr *service.UpstreamFailoverError
ForceCacheBilling bool
hasBoundSession bool
}
// NewFailoverState 创建 failover 状态
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
return &FailoverState{
MaxSwitches: maxSwitches,
FailedAccountIDs: make(map[int64]struct{}),
SameAccountRetryCount: make(map[int64]int),
hasBoundSession: hasBoundSession,
}
}
// HandleFailoverError 处理 UpstreamFailoverError返回下一步动作。
// 包含缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
func (s *FailoverState) HandleFailoverError(
ctx context.Context,
gatewayService TempUnscheduler,
accountID int64,
platform string,
failoverErr *service.UpstreamFailoverError,
) FailoverAction {
s.LastFailoverErr = failoverErr
// 缓存计费判断
if needForceCacheBilling(s.hasBoundSession, failoverErr) {
s.ForceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
s.SameAccountRetryCount[accountID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
if !sleepWithContext(ctx, sameAccountRetryDelay) {
return FailoverCanceled
}
return FailoverContinue
}
// 同账号重试用尽,执行临时封禁
if failoverErr.RetryableOnSameAccount {
gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
}
// 加入失败列表
s.FailedAccountIDs[accountID] = struct{}{}
// 检查是否耗尽
if s.SwitchCount >= s.MaxSwitches {
return FailoverExhausted
}
// 递增切换计数
s.SwitchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d",
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
// Antigravity 平台换号线性递增延时
if platform == service.PlatformAntigravity {
delay := time.Duration(s.SwitchCount-1) * time.Second
if !sleepWithContext(ctx, delay) {
return FailoverCanceled
}
}
return FailoverContinue
}
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
// 清除排除列表、等待退避后重新选号。
//
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
// 返回 FailoverExhausted 时,调用方应返回错误响应。
// 返回 FailoverCanceled 时,调用方应直接 return。
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
if s.LastFailoverErr != nil &&
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
s.SwitchCount <= s.MaxSwitches {
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
singleAccountBackoffDelay, s.SwitchCount)
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
return FailoverCanceled
}
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
s.SwitchCount, s.MaxSwitches)
s.FailedAccountIDs = make(map[int64]struct{})
return FailoverContinue
}
return FailoverExhausted
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
func sleepWithContext(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(d):
return true
}
}

View File

@@ -0,0 +1,732 @@
package handler
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mock
// ---------------------------------------------------------------------------
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
type mockTempUnscheduler struct {
calls []tempUnscheduleCall
}
type tempUnscheduleCall struct {
accountID int64
failoverErr *service.UpstreamFailoverError
}
func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) {
m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr})
}
// ---------------------------------------------------------------------------
// Helper
// ---------------------------------------------------------------------------
func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError {
return &service.UpstreamFailoverError{
StatusCode: statusCode,
RetryableOnSameAccount: retryable,
ForceCacheBilling: forceBilling,
}
}
// ---------------------------------------------------------------------------
// NewFailoverState 测试
// ---------------------------------------------------------------------------
func TestNewFailoverState(t *testing.T) {
t.Run("初始化字段正确", func(t *testing.T) {
fs := NewFailoverState(5, true)
require.Equal(t, 5, fs.MaxSwitches)
require.Equal(t, 0, fs.SwitchCount)
require.NotNil(t, fs.FailedAccountIDs)
require.Empty(t, fs.FailedAccountIDs)
require.NotNil(t, fs.SameAccountRetryCount)
require.Empty(t, fs.SameAccountRetryCount)
require.Nil(t, fs.LastFailoverErr)
require.False(t, fs.ForceCacheBilling)
require.True(t, fs.hasBoundSession)
})
t.Run("无绑定会话", func(t *testing.T) {
fs := NewFailoverState(3, false)
require.Equal(t, 3, fs.MaxSwitches)
require.False(t, fs.hasBoundSession)
})
t.Run("零最大切换次数", func(t *testing.T) {
fs := NewFailoverState(0, false)
require.Equal(t, 0, fs.MaxSwitches)
})
}
// ---------------------------------------------------------------------------
// sleepWithContext 测试
// ---------------------------------------------------------------------------
func TestSleepWithContext(t *testing.T) {
t.Run("零时长立即返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), 0)
require.True(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("负时长立即返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), -1*time.Second)
require.True(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("正常等待后返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), 50*time.Millisecond)
elapsed := time.Since(start)
require.True(t, ok)
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond)
require.Less(t, elapsed, 500*time.Millisecond)
})
t.Run("已取消context立即返回false", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
ok := sleepWithContext(ctx, 5*time.Second)
require.False(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("等待期间context取消返回false", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(30 * time.Millisecond)
cancel()
}()
start := time.Now()
ok := sleepWithContext(ctx, 5*time.Second)
elapsed := time.Since(start)
require.False(t, ok)
require.Less(t, elapsed, 500*time.Millisecond)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 基本切换流程
// ---------------------------------------------------------------------------
func TestHandleFailoverError_BasicSwitch(t *testing.T) {
t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
require.Equal(t, err, fs.LastFailoverErr)
require.False(t, fs.ForceCacheBilling)
require.Empty(t, mock.calls, "不应调用 TempUnschedule")
})
t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) {
// switchCount 从 0→1 时sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0")
})
t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) {
// switchCount 从 1→2 时sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1 // 模拟已切换一次
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s")
require.Less(t, elapsed, 3*time.Second)
})
t.Run("连续切换直到耗尽", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(2, false)
// 第一次切换0→1
err1 := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
// 第二次切换1→2
err2 := newTestFailoverErr(502, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
// 第三次已耗尽SwitchCount(2) >= MaxSwitches(2)
err3 := newTestFailoverErr(503, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
require.Equal(t, FailoverExhausted, action)
require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增")
// 验证失败账号列表
require.Len(t, fs.FailedAccountIDs, 3)
require.Contains(t, fs.FailedAccountIDs, int64(100))
require.Contains(t, fs.FailedAccountIDs, int64(200))
require.Contains(t, fs.FailedAccountIDs, int64(300))
// LastFailoverErr 应为最后一次的错误
require.Equal(t, err3, fs.LastFailoverErr)
})
t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(0, false)
err := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverExhausted, action)
require.Equal(t, 0, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
// ---------------------------------------------------------------------------
func TestHandleFailoverError_CacheBilling(t *testing.T) {
t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, true) // hasBoundSession=true
err := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.True(t, fs.ForceCacheBilling)
})
t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.True(t, fs.ForceCacheBilling)
})
t.Run("两者均为false时不设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.False(t, fs.ForceCacheBilling)
})
t.Run("一旦设置不会被后续错误重置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
// 第一次ForceCacheBilling=true → 设置
err1 := newTestFailoverErr(500, false, true)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.True(t, fs.ForceCacheBilling)
// 第二次ForceCacheBilling=false → 仍然保持 true
err2 := newTestFailoverErr(502, false, false)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
// ---------------------------------------------------------------------------
func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
t.Run("第一次重试返回FailoverContinue", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数")
require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表")
require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule")
// 验证等待了 sameAccountRetryDelay (500ms)
require.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
require.Less(t, elapsed, 2*time.Second)
})
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
// 第二次
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
})
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次、第二次重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
// 验证 TempUnschedule 被调用
require.Len(t, mock.calls, 1)
require.Equal(t, int64(100), mock.calls[0].accountID)
require.Equal(t, err, mock.calls[0].failoverErr)
})
t.Run("不同账号独立跟踪重试次数", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
err := newTestFailoverErr(400, true, false)
// 账号 100 第一次重试
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
// 账号 200 第一次重试(独立计数)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[200])
require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响")
})
t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
err := newTestFailoverErr(400, true, false)
// 耗尽账号 100 的重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
// 第三次: 重试耗尽 → 切换
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
// 再次遇到账号 100计数仍为 2条件不满足 → 直接切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — TempUnschedule 调用验证
// ---------------------------------------------------------------------------
func TestHandleFailoverError_TempUnschedule(t *testing.T) {
t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Empty(t, mock.calls)
})
t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(502, true, false)
// 耗尽重试
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
require.Len(t, mock.calls, 1)
require.Equal(t, int64(42), mock.calls[0].accountID)
require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode)
require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — Context 取消
// ---------------------------------------------------------------------------
func TestHandleFailoverError_ContextCanceled(t *testing.T) {
t.Run("同账号重试sleep期间context取消", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
action := fs.HandleFailoverError(ctx, mock, 100, "openai", err)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
// 重试计数仍应递增
require.Equal(t, 1, fs.SameAccountRetryCount[100])
})
t.Run("Antigravity延迟期间context取消", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s
err := newTestFailoverErr(500, false, false)
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — FailedAccountIDs 跟踪
// ---------------------------------------------------------------------------
func TestHandleFailoverError_FailedAccountIDs(t *testing.T) {
t.Run("切换时添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Contains(t, fs.FailedAccountIDs, int64(100))
fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false))
require.Contains(t, fs.FailedAccountIDs, int64(200))
require.Len(t, fs.FailedAccountIDs, 2)
})
t.Run("耗尽时也添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(0, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Equal(t, FailoverExhausted, action)
require.Contains(t, fs.FailedAccountIDs, int64(100))
})
t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false))
require.Equal(t, FailoverContinue, action)
require.NotContains(t, fs.FailedAccountIDs, int64(100))
})
t.Run("同一账号多次切换不重复添加", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — LastFailoverErr 更新
// ---------------------------------------------------------------------------
func TestHandleFailoverError_LastFailoverErr(t *testing.T) {
t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err1 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.Equal(t, err1, fs.LastFailoverErr)
err2 := newTestFailoverErr(502, false, false)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.Equal(t, err2, fs.LastFailoverErr)
})
t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, err, fs.LastFailoverErr)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 综合集成场景
// ---------------------------------------------------------------------------
func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, true) // hasBoundSession=true
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
retryErr := newTestFailoverErr(400, true, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Len(t, mock.calls, 1)
// 3. 账号 200 遇到不可重试错误 → 直接切换
switchErr := newTestFailoverErr(500, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
// 4. 账号 300 遇到不可重试错误 → 再切换
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 3, fs.SwitchCount)
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr)
require.Equal(t, FailoverExhausted, action)
// 最终状态验证
require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增")
require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中")
require.True(t, fs.ForceCacheBilling)
require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule")
})
t.Run("模拟Antigravity平台完整流程", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(2, false)
err := newTestFailoverErr(500, false, false)
// 第一次切换delay = 0s
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0")
// 第二次切换delay = 1s
start = time.Now()
action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
elapsed = time.Since(start)
require.Equal(t, FailoverContinue, action)
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s")
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
start = time.Now()
action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err)
elapsed = time.Since(start)
require.Equal(t, FailoverExhausted, action)
require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟")
})
t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false) // hasBoundSession=false
// 第一次ForceCacheBilling=false
err1 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.False(t, fs.ForceCacheBilling)
// 第二次ForceCacheBilling=trueAntigravity 粘性会话切换)
err2 := newTestFailoverErr(500, false, true)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling")
// 第三次ForceCacheBilling=false但状态仍保持 true
err3 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
require.True(t, fs.ForceCacheBilling, "不应重置")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 边界条件
// ---------------------------------------------------------------------------
func TestHandleFailoverError_EdgeCases(t *testing.T) {
t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(0, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
})
t.Run("AccountID为0也能正常跟踪", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, true, false)
action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[0])
})
t.Run("负AccountID也能正常跟踪", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, true, false)
action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[-1])
})
t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, "", err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟")
})
}
// ---------------------------------------------------------------------------
// HandleSelectionExhausted 测试
// ---------------------------------------------------------------------------
func TestHandleSelectionExhausted(t *testing.T) {
t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(3, false)
// LastFailoverErr 为 nil
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverExhausted, action)
})
t.Run("非503错误返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(500, false, false)
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverExhausted, action)
})
t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.FailedAccountIDs[100] = struct{}{}
fs.SwitchCount = 1
start := time.Now()
action := fs.HandleSelectionExhausted(context.Background())
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表")
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s")
require.Less(t, elapsed, 5*time.Second)
})
t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(2, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.SwitchCount = 3 // > MaxSwitches(2)
start := time.Now()
action := fs.HandleSelectionExhausted(context.Background())
elapsed := time.Since(start)
require.Equal(t, FailoverExhausted, action)
require.Less(t, elapsed, 100*time.Millisecond, "不应等待")
})
t.Run("503但context已取消_返回Canceled", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
action := fs.HandleSelectionExhausted(ctx)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
})
t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) {
fs := NewFailoverState(2, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.SwitchCount = 2 // == MaxSwitches条件是 <=,仍可重试
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverContinue, action)
})
}

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
if platform == service.PlatformGemini {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
@@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
reqLog.Warn("gateway.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
if fs.LastFailoverErr != nil {
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
@@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
case FailoverExhausted:
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
return
case FailoverCanceled:
return
}
switchCount++
reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed",
@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: forceCacheBilling,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
reqLog.Warn("gateway.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
if fs.LastFailoverErr != nil {
h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
@@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
case FailoverExhausted:
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
return
case FailoverCanceled:
return
}
switchCount++
reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed",
@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: forceCacheBilling,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -733,11 +650,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
).Error("gateway.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("gateway.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
zap.Bool("fallback_used", fallbackUsed),
)
return
}
if !retryWithFallback {
@@ -982,69 +894,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
)
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
func sleepSameAccountRetryDelay(ctx context.Context) bool {
select {
case <-ctx.Done():
return false
case <-time.After(sameAccountRetryDelay):
return true
}
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
delay := time.Duration(switchCount-1) * time.Second
if delay <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
// 当分组内只有一个可用账号且上游返回 503MODEL_CAPACITY_EXHAUSTED时使用
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
// (最多 3 次、总等待 30s所以 Handler 层的退避只需短暂等待即可。
// 返回 false 表示 context 已取消。
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
// 固定短延时2s
// Service 层已经在原地等待了足够长的时间retryDelay × 重试次数),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
const delay = 2 * time.Second
logger.L().With(
zap.String("component", "handler.gateway.failover"),
zap.Duration("delay", delay),
zap.Int("retry_count", retryCount),
).Info("gateway.single_account_backoff_waiting")
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody

View File

@@ -1,51 +0,0 @@
package handler
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// sleepAntigravitySingleAccountBackoff 测试
// ---------------------------------------------------------------------------
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.True(t, ok, "should return true when context is not canceled")
// 固定延迟 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
}
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.False(t, ok, "should return false when context is canceled")
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
}
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
// 验证不同 retryCount 都使用固定 2s 延迟
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
elapsed := time.Since(start)
require.True(t, ok)
// 即使 retryCount=5延迟仍然是固定的 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
require.Less(t, elapsed, 5*time.Second)
}

View File

@@ -0,0 +1,340 @@
//go:build unit
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// 目标严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
type fakeSchedulerCache struct {
accounts []*service.Account
}
func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) {
return f.accounts, true, nil
}
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
return nil
}
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
return nil, nil
}
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil }
func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
return nil
}
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
return true, nil
}
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil }
func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil }
type fakeGroupRepo struct {
group *service.Group
}
func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil }
func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) {
return f.group, nil
}
func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
return f.group, nil
}
func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil }
func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil }
func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil }
func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil }
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
return nil, nil
}
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
return nil, nil
}
func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil }
func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
return nil
}
type fakeConcurrencyCache struct{}
func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil }
func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
return 0, nil
}
func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
return 0, nil
}
func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil }
func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil }
func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
return map[int64]*service.AccountLoadInfo{}, nil
}
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
return map[int64]*service.UserLoadInfo{}, nil
}
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
t.Helper()
schedulerCache := &fakeSchedulerCache{accounts: accounts}
schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil)
gwSvc := service.NewGatewayService(
nil, // accountRepo (not used: scheduler snapshot hit)
&fakeGroupRepo{group: group},
nil, // usageLogRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache (disable sticky)
nil, // cfg
schedulerSnapshot,
nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
nil, // billingService
nil, // rateLimitService
nil, // billingCacheService
nil, // identityService
nil, // httpUpstream
nil, // deferredService
nil, // claudeTokenProvider
nil, // sessionLimitCache
nil, // digestStore
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
h := &GatewayHandler{
gatewayService: gwSvc,
billingCacheService: billingCacheSvc,
concurrencyHelper: concurrencyHelper,
// 这些字段对本测试不敏感,保持较小即可
maxAccountSwitches: 1,
maxAccountSwitchesGemini: 1,
}
cleanup := func() {
billingCacheSvc.Stop()
}
return h, cleanup
}
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(2001)
accountID := int64(1001)
group := &service.Group{
ID: groupID,
Hydrated: true,
Platform: service.PlatformAnthropic, // /v1/messagesClaude兼容入口
Status: service.StatusActive,
}
account := &service.Account{
ID: accountID,
Name: "ag-1",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "tok_xxx",
"intercept_warmup_requests": true,
},
Extra: map[string]any{
"mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中
},
Concurrency: 1,
Priority: 1,
Status: service.StatusActive,
Schedulable: true,
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
}
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
defer cleanup()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`)
req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group))
c.Request = req
apiKey := &service.APIKey{
ID: 3001,
UserID: 4001,
GroupID: &groupID,
Status: service.StatusActive,
User: &service.User{
ID: 4001,
Concurrency: 10,
Balance: 100,
},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
h.Messages(c)
require.Equal(t, 200, rec.Code)
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
selected, ok := c.Get(opsAccountIDKey)
require.True(t, ok)
require.Equal(t, accountID, selected)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "msg_mock_warmup", resp["id"])
require.Equal(t, "claude-sonnet-4-5", resp["model"])
content, ok := resp["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
first, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "New Conversation", first["text"])
}
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(2002)
accountID := int64(1002)
group := &service.Group{
ID: groupID,
Hydrated: true,
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
}
account := &service.Account{
ID: accountID,
Name: "ag-2",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "tok_xxx",
"intercept_warmup_requests": true,
},
Concurrency: 1,
Priority: 1,
Status: service.StatusActive,
Schedulable: true,
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
}
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
defer cleanup()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`)
req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
// - 写入 request.ContextService读取
// - 写入 gin.ContextHandler快速读取
ctx := context.WithValue(req.Context(), ctxkey.Group, group)
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity)
req = req.WithContext(ctx)
c.Request = req
c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)
apiKey := &service.APIKey{
ID: 3002,
UserID: 4002,
GroupID: &groupID,
Status: service.StatusActive,
User: &service.User{
ID: 4002,
Concurrency: 10,
Balance: 100,
},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
h.Messages(c)
require.Equal(t, 200, rec.Code)
selected, ok := c.Get(opsAccountIDKey)
require.True(t, ok)
require.Equal(t, accountID, selected)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "msg_mock_warmup", resp["id"])
require.Equal(t, "claude-sonnet-4-5", resp["model"])
}

View File

@@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
@@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
reqLog.Warn("gemini.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
return
}
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
@@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
lastFailoverErr = failoverErr
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch failoverAction {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
return
case FailoverCanceled:
return
}
lastFailoverErr = failoverErr
switchCount++
reqLog.Warn("gemini.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// ForwardNative already wrote the response
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
@@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: forceCacheBilling,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
})
reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
zap.Int("switch_count", fs.SwitchCount),
)
return
}

View File

@@ -400,7 +400,9 @@ func TestShouldFallbackToNextURL_无错误且200(t *testing.T) {
// ---------------------------------------------------------------------------
func TestClient_ExchangeCode_成功(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求方法
@@ -493,7 +495,9 @@ func TestClient_ExchangeCode_成功(t *testing.T) {
}
func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := NewClient("")
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
@@ -506,7 +510,9 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
}
func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
@@ -531,7 +537,9 @@ func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
// ---------------------------------------------------------------------------
func TestClient_RefreshToken_MockServer(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -590,7 +598,9 @@ func TestClient_RefreshToken_MockServer(t *testing.T) {
}
func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := NewClient("")
_, err := client.RefreshToken(context.Background(), "refresh-tok")
@@ -784,7 +794,9 @@ func newTestClientWithRedirect(redirects map[string]string) *Client {
// ---------------------------------------------------------------------------
func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -853,7 +865,9 @@ func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
}
func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
@@ -878,7 +892,9 @@ func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
}
func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@@ -901,7 +917,9 @@ func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
}
func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second) // 模拟慢响应
@@ -927,7 +945,9 @@ func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
// ---------------------------------------------------------------------------
func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -976,7 +996,9 @@ func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
}
func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
@@ -998,7 +1020,9 @@ func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
}
func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@@ -1021,7 +1045,9 @@ func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
}
func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)

View File

@@ -23,11 +23,9 @@ const (
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
// Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = ""
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
// 出于安全原因,该值不得硬编码入库。
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
// 固定的 redirect_uri用户需手动复制 code
@@ -51,14 +49,21 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.84.2
var defaultUserAgentVersion = "1.84.2"
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.18.4
var defaultUserAgentVersion = "1.18.4"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"
func init() {
// 从环境变量读取版本号,未设置则使用默认值
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
defaultUserAgentVersion = version
}
// 从环境变量读取 client_secret未设置则使用默认值
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
defaultClientSecret = secret
}
}
// GetUserAgent 返回当前配置的 User-Agent
@@ -67,14 +72,9 @@ func GetUserAgent() string {
}
func getClientSecret() (string, error) {
if v := strings.TrimSpace(ClientSecret); v != "" {
if v := strings.TrimSpace(defaultClientSecret); v != "" {
return v, nil
}
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
if vv := strings.TrimSpace(v); vv != "" {
return vv, nil
}
}
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/hex"
"net/url"
"os"
"strings"
"testing"
"time"
@@ -17,8 +18,14 @@ import (
// ---------------------------------------------------------------------------
func TestGetClientSecret_环境变量设置(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
// 需要重新触发 init 逻辑:手动从环境变量读取
defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv)
secret, err := getClientSecret()
if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err)
@@ -29,11 +36,13 @@ func TestGetClientSecret_环境变量设置(t *testing.T) {
}
func TestGetClientSecret_环境变量为空(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量为空时应返回错误")
t.Fatal("defaultClientSecret 为空时应返回错误")
}
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
@@ -41,30 +50,31 @@ func TestGetClientSecret_环境变量为空(t *testing.T) {
}
func TestGetClientSecret_环境变量未设置(t *testing.T) {
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
// 明确设置再取消,确保环境变量不存在
t.Setenv(AntigravityOAuthClientSecretEnv, "")
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量未设置时应返回错误")
t.Fatal("defaultClientSecret 为空时应返回错误")
}
}
func TestGetClientSecret_环境变量含空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
old := defaultClientSecret
defaultClientSecret = " "
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量仅含空格时应返回错误")
t.Fatal("defaultClientSecret 仅含空格时应返回错误")
}
}
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
old := defaultClientSecret
defaultClientSecret = " valid-secret "
t.Cleanup(func() { defaultClientSecret = old })
secret, err := getClientSecret()
if err != nil {
@@ -670,13 +680,17 @@ func TestConstants_值正确(t *testing.T) {
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
t.Errorf("ClientID 不匹配: got %s", ClientID)
}
if ClientSecret != "" {
t.Error("ClientSecret 应为空字符串")
secret, err := getClientSecret()
if err != nil {
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
}
if secret != "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" {
t.Errorf("默认 client_secret 不匹配: got %s", secret)
}
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
if GetUserAgent() != "antigravity/1.84.2 windows/amd64" {
if GetUserAgent() != "antigravity/1.18.4 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
}
if SessionTTL != 30*time.Minute {

View File

@@ -206,6 +206,7 @@ type modelInfo struct {
var modelInfoMap = map[string]modelInfo{
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
"claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"},
"claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"},
}

View File

@@ -219,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)

View File

@@ -0,0 +1,66 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_IsInterceptWarmupEnabled(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
expected bool
}{
{
name: "nil credentials",
credentials: nil,
expected: false,
},
{
name: "empty map",
credentials: map[string]any{},
expected: false,
},
{
name: "field not present",
credentials: map[string]any{"access_token": "tok"},
expected: false,
},
{
name: "field is true",
credentials: map[string]any{"intercept_warmup_requests": true},
expected: true,
},
{
name: "field is false",
credentials: map[string]any{"intercept_warmup_requests": false},
expected: false,
},
{
name: "field is string true",
credentials: map[string]any{"intercept_warmup_requests": "true"},
expected: false,
},
{
name: "field is int 1",
credentials: map[string]any{"intercept_warmup_requests": 1},
expected: false,
},
{
name: "field is nil",
credentials: map[string]any{"intercept_warmup_requests": nil},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Credentials: tt.credentials}
result := a.IsInterceptWarmupEnabled()
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -54,6 +54,7 @@ type AdminService interface {
SetAccountError(ctx context.Context, id int64, errorMsg string) error
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
@@ -2114,6 +2115,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
}
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
if s.proxyLatencyCache == nil || len(proxies) == 0 {
return

View File

@@ -87,7 +87,6 @@ var (
)
const (
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
@@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
billingModel := mappedModel
// 获取 access_token
if s.tokenProvider == nil {
@@ -1370,6 +1370,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
ForceCacheBilling: switchErr.IsStickySession,
}
}
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if c.Request.Context().Err() != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
resp := result.resp
@@ -1618,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Model: billingModel, // 使用映射模型用于计费和日志
Stream: claudeReq.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
@@ -1972,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if mappedModel == "" {
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
}
billingModel := mappedModel
// 获取 access_token
if s.tokenProvider == nil {
@@ -2042,6 +2047,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
ForceCacheBilling: switchErr.IsStickySession,
}
}
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if c.Request.Context().Err() != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response")
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
resp := result.resp
@@ -2197,7 +2206,7 @@ handleSuccess:
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Model: billingModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
@@ -2642,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
defaultDur := s.getDefaultRateLimitDuration()
// 尝试解析模型 key 并设置模型级限流
modelKey := resolveAntigravityModelKey(requestedModel)
//
// 注意requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
// 因此这里必须写入最终模型 key确保后续调度能正确避开已限流模型。
modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel)
if strings.TrimSpace(modelKey) == "" {
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
modelKey = resolveAntigravityModelKey(requestedModel)
}
if modelKey != "" {
ra := s.resolveResetTime(resetAt, defaultDur)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
@@ -3881,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
@@ -3934,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
Model: originalModel,
}, nil
}
@@ -3975,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Model: originalModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,

View File

@@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
return s.resp, s.err
}
type antigravitySettingRepoStub struct{}
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
return "", ErrSettingNotFound
}
func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
@@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
account := &Account{
@@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
// 验证Antigravity Claude 转发返回的计费模型使用映射后的模型
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{
{"role": "user", "content": "hello"},
},
"max_tokens": 16,
"stream": true,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-bill-1"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
const mappedModel = "gemini-3-pro-high"
account := &Account{
ID: 5,
Name: "acc-forward-billing",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
"claude-sonnet-4-5": mappedModel,
},
},
}
result, err := svc.Forward(context.Background(), c, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
}
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
// 验证Antigravity Gemini 转发返回的计费模型使用映射后的模型
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-bill-2"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
const mappedModel = "gemini-3-pro-high"
account := &Account{
ID: 6,
Name: "acc-gemini-billing",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
"gemini-2.5-flash": mappedModel,
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
}
// TestStreamUpstreamResponse_UsageAndFirstToken
// 验证usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {

View File

@@ -76,6 +76,12 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
},
// 3. 默认映射中的透传(映射到自己)
{
name: "默认映射透传 - claude-sonnet-4-6",
requestedModel: "claude-sonnet-4-6",
accountMapping: nil,
expected: "claude-sonnet-4-6",
},
{
name: "默认映射透传 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",

View File

@@ -197,6 +197,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景
// 验证requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking
func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("5s")
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false)
require.Nil(t, result)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {

View File

@@ -133,6 +133,18 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
SupportsCacheBreakdown: false,
}
// Claude 4.6 Opus (与4.5同价)
s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"]
// Gemini 3.1 Pro
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
InputPricePerToken: 2e-6, // $2 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
CacheCreationPricePerToken: 2e-6, // $2 per MTok
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
SupportsCacheBreakdown: false,
}
}
// getFallbackPricing 根据模型系列获取回退价格
@@ -141,6 +153,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
// 按模型系列匹配
if strings.Contains(modelLower, "opus") {
if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") {
return s.fallbackPrices["claude-opus-4.6"]
}
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
return s.fallbackPrices["claude-opus-4.5"]
}
@@ -158,6 +173,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
}
return s.fallbackPrices["claude-3-haiku"]
}
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
return s.fallbackPrices["gemini-3.1-pro"]
}
// 默认使用Sonnet价格
return s.fallbackPrices["claude-sonnet-4"]

View File

@@ -895,6 +895,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t
require.Equal(t, int64(2), acc.ID)
}
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}},
},
{
ID: 2,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Priority: 2,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}},
},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号")
acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "supporting model")
}
func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(50)
@@ -1070,6 +1119,36 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Gemini平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
},
},
model: "gemini-2.5-flash",
expected: false,
},
{
name: "Gemini平台-有映射配置-支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
},
},
model: "gemini-2.5-pro",
expected: true,
},
}
for _, tt := range tests {

View File

@@ -2825,10 +2825,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
return true
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
}

View File

@@ -107,12 +107,12 @@ func TestIsModelRateLimited(t *testing.T) {
expected: true,
},
{
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3.1-pro-high",
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3-pro-high",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3.1-pro-high": map[string]any{
"gemini-3-pro-high": map[string]any{
"rate_limit_reset_at": future,
},
},

View File

@@ -0,0 +1,42 @@
-- Add claude-sonnet-4-6 to model_mapping for all Antigravity accounts
--
-- Background:
-- Antigravity now supports claude-sonnet-4-6
--
-- Strategy:
-- Directly overwrite the entire model_mapping with updated mappings
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{model_mapping}',
'{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-6": "claude-opus-4-6-thinking",
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}'::jsonb
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping' IS NOT NULL;

View File

@@ -0,0 +1,45 @@
-- Add gemini-3.1-pro-high, gemini-3.1-pro-low, gemini-3.1-pro-preview to model_mapping
--
-- Background:
-- Antigravity now supports gemini-3.1-pro-high and gemini-3.1-pro-low
--
-- Strategy:
-- Directly overwrite the entire model_mapping with updated mappings
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{model_mapping}',
'{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-6": "claude-opus-4-6-thinking",
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}'::jsonb
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping' IS NOT NULL;

View File

@@ -15,7 +15,9 @@ import type {
AccountUsageStatsResponse,
TempUnschedulableStatus,
AdminDataPayload,
AdminDataImportResult
AdminDataImportResult,
CheckMixedChannelRequest,
CheckMixedChannelResponse
} from '@/types'
/**
@@ -133,6 +135,16 @@ export async function update(id: number, updates: UpdateAccountRequest): Promise
return data
}
/**
* Check mixed-channel risk for account-group binding.
*/
export async function checkMixedChannelRisk(
payload: CheckMixedChannelRequest
): Promise<CheckMixedChannelResponse> {
const { data } = await apiClient.post<CheckMixedChannelResponse>('/admin/accounts/check-mixed-channel', payload)
return data
}
/**
* Delete account
* @param id - Account ID
@@ -535,6 +547,7 @@ export const accountsAPI = {
getById,
create,
update,
checkMixedChannelRisk,
delete: deleteAccount,
toggleStatus,
testAccount,

View File

@@ -77,13 +77,23 @@
</div>
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
<template v-if="activeModelRateLimits.length > 0">
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
<div
v-if="activeModelRateLimits.length > 0"
:class="[
activeModelRateLimits.length <= 4
? 'flex flex-col gap-1'
: activeModelRateLimits.length <= 8
? 'columns-2 gap-x-2'
: 'columns-3 gap-x-2'
]"
>
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative mb-1 break-inside-avoid">
<span
class="inline-flex items-center gap-1 rounded bg-purple-100 px-1.5 py-0.5 text-xs font-medium text-purple-700 dark:bg-purple-900/30 dark:text-purple-400"
>
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
{{ formatScopeName(item.model) }}
<span class="text-[10px] opacity-70">{{ formatModelResetTime(item.reset_at) }}</span>
</span>
<!-- Tooltip -->
<div
@@ -95,7 +105,7 @@
></div>
</div>
</div>
</template>
</div>
<!-- Overload Indicator (529) -->
<div v-if="isOverloaded" class="group relative">
@@ -154,17 +164,50 @@ const activeModelRateLimits = computed(() => {
})
const formatScopeName = (scope: string): string => {
const names: Record<string, string> = {
const aliases: Record<string, string> = {
// Claude 系列
'claude-opus-4-6-thinking': 'COpus46',
'claude-sonnet-4-6': 'CSon46',
'claude-sonnet-4-5': 'CSon45',
'claude-sonnet-4-5-thinking': 'CSon45T',
// Gemini 2.5 系列
'gemini-2.5-flash': 'G25F',
'gemini-2.5-flash-lite': 'G25FL',
'gemini-2.5-flash-thinking': 'G25FT',
'gemini-2.5-pro': 'G25P',
// Gemini 3 系列
'gemini-3-flash': 'G3F',
'gemini-3.1-pro-high': 'G3PH',
'gemini-3.1-pro-low': 'G3PL',
'gemini-3-pro-image': 'G3PI',
// 其他
'gpt-oss-120b-medium': 'GPT120',
'tab_flash_lite_preview': 'TabFL',
// 旧版 scope 别名(兼容)
claude: 'Claude',
claude_sonnet: 'Claude Sonnet',
claude_opus: 'Claude Opus',
claude_haiku: 'Claude Haiku',
claude_sonnet: 'CSon',
claude_opus: 'COpus',
claude_haiku: 'CHaiku',
gemini_text: 'Gemini',
gemini_image: 'Image',
gemini_flash: 'Gemini Flash',
gemini_pro: 'Gemini Pro'
gemini_image: 'GImg',
gemini_flash: 'GFlash',
gemini_pro: 'GPro',
}
return names[scope] || scope
return aliases[scope] || scope
}
const formatModelResetTime = (resetAt: string): string => {
const date = new Date(resetAt)
const now = new Date()
const diffMs = date.getTime() - now.getTime()
if (diffMs <= 0) return ''
const totalSecs = Math.floor(diffMs / 1000)
const h = Math.floor(totalSecs / 3600)
const m = Math.floor((totalSecs % 3600) / 60)
const s = totalSecs % 60
if (h > 0) return `${h}h${m}m`
if (m > 0) return `${m}m${s}s`
return `${s}s`
}
// Computed: is overloaded (529)

View File

@@ -172,12 +172,12 @@
color="purple"
/>
<!-- Claude 4.5 -->
<!-- Claude -->
<UsageProgressBar
v-if="antigravityClaude45UsageFromAPI !== null"
:label="t('admin.accounts.usageWindow.claude45')"
:utilization="antigravityClaude45UsageFromAPI.utilization"
:resets-at="antigravityClaude45UsageFromAPI.resetTime"
v-if="antigravityClaudeUsageFromAPI !== null"
:label="t('admin.accounts.usageWindow.claude')"
:utilization="antigravityClaudeUsageFromAPI.utilization"
:resets-at="antigravityClaudeUsageFromAPI.resetTime"
color="amber"
/>
</div>
@@ -400,9 +400,12 @@ const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(
// Gemini 3 Image from API
const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-pro-image']))
// Claude 4.5 from API
const antigravityClaude45UsageFromAPI = computed(() =>
getAntigravityUsageFromAPI(['claude-sonnet-4-5', 'claude-opus-4-5-thinking'])
// Claude from API (all Claude model variants)
const antigravityClaudeUsageFromAPI = computed(() =>
getAntigravityUsageFromAPI([
'claude-sonnet-4-5', 'claude-opus-4-5-thinking',
'claude-sonnet-4-6', 'claude-opus-4-6-thinking',
])
)
// Antigravity 账户类型(从 load_code_assist 响应中提取)

View File

@@ -209,7 +209,7 @@
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in modelMappings"
:key="getModelMappingKey(mapping)"
:key="index"
class="flex items-center gap-2"
>
<input
@@ -654,7 +654,7 @@ import Select from '@/components/common/Select.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import Icon from '@/components/icons/Icon.vue'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import { buildModelMappingObject as buildModelMappingPayload } from '@/composables/useModelWhitelist'
interface Props {
show: boolean
@@ -696,7 +696,6 @@ const baseUrl = ref('')
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([])
const modelMappings = ref<ModelMapping[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('bulk-model-mapping')
const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
@@ -707,7 +706,7 @@ const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([])
// All models list (combined Anthropic + OpenAI)
// All models list (combined Anthropic + OpenAI + Gemini)
const allModels = [
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
@@ -719,17 +718,21 @@ const allModels = [
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
{ value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' },
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' }
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
{ value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' },
{ value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' },
{ value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' },
{ value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' },
{ value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' }
]
// Preset mappings (combined Anthropic + OpenAI)
// Preset mappings (combined Anthropic + OpenAI + Gemini)
const presetMappings = [
{
label: 'Sonnet 4',
@@ -796,12 +799,6 @@ const presetMappings = [
to: 'claude-sonnet-4-5-20250929',
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
},
{
label: 'GPT-5.3 Codex Spark',
from: 'gpt-5.3-codex-spark',
to: 'gpt-5.3-codex-spark',
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
},
{
label: 'GPT-5.2',
from: 'gpt-5.2-2025-12-11',
@@ -938,23 +935,11 @@ const removeErrorCode = (code: number) => {
}
const buildModelMappingObject = (): Record<string, string> | null => {
const mapping: Record<string, string> = {}
if (modelRestrictionMode.value === 'whitelist') {
for (const model of allowedModels.value) {
mapping[model] = model
}
} else {
for (const m of modelMappings.value) {
const from = m.from.trim()
const to = m.to.trim()
if (from && to) {
mapping[from] = to
}
}
}
return Object.keys(mapping).length > 0 ? mapping : null
return buildModelMappingPayload(
modelRestrictionMode.value,
allowedModels.value,
modelMappings.value
)
}
const buildUpdatePayload = (): Record<string, unknown> | null => {

View File

@@ -916,8 +916,8 @@
<p class="input-hint">{{ t('admin.accounts.gemini.tier.aiStudioHint') }}</p>
</div>
<!-- Model Restriction Section (不适用于 GeminiAntigravity 已在上层条件排除) -->
<div v-if="form.platform !== 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<!-- Model Restriction Section (Antigravity 已在上层条件排除) -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<div
@@ -1200,34 +1200,6 @@
</div>
</div>
<!-- Gemini 模型说明 -->
<div v-if="form.platform === 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="rounded-lg bg-blue-50 p-4 dark:bg-blue-900/20">
<div class="flex items-start gap-3">
<svg
class="h-5 w-5 flex-shrink-0 text-blue-600 dark:text-blue-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
<div>
<p class="text-sm font-medium text-blue-800 dark:text-blue-300">
{{ t('admin.accounts.gemini.modelPassthrough') }}
</p>
<p class="mt-1 text-xs text-blue-700 dark:text-blue-400">
{{ t('admin.accounts.gemini.modelPassthroughDesc') }}
</p>
</div>
</div>
</div>
</div>
</div>
<!-- Temp Unschedulable Rules -->
@@ -1378,9 +1350,9 @@
</div>
</div>
<!-- Intercept Warmup Requests (Anthropic only) -->
<!-- Intercept Warmup Requests (Anthropic/Antigravity) -->
<div
v-if="form.platform === 'anthropic'"
v-if="form.platform === 'anthropic' || form.platform === 'antigravity'"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
>
<div class="flex items-center justify-between">
@@ -2157,7 +2129,7 @@
<ConfirmDialog
:show="showMixedChannelWarning"
:title="t('admin.accounts.mixedChannelWarningTitle')"
:message="mixedChannelWarningDetails ? t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails) : ''"
:message="mixedChannelWarningMessageText"
:confirm-text="t('common.confirm')"
:cancel-text="t('common.cancel')"
:danger="true"
@@ -2189,13 +2161,21 @@ import {
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
import type { Proxy, AdminGroup, AccountPlatform, AccountType } from '@/types'
import type {
Proxy,
AdminGroup,
AccountPlatform,
AccountType,
CheckMixedChannelResponse,
CreateAccountRequest
} from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Icon from '@/components/icons/Icon.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
@@ -2337,10 +2317,13 @@ const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
const geminiAIStudioOAuthEnabled = ref(false)
// Mixed channel warning dialog state
const showMixedChannelWarning = ref(false)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null)
const pendingCreatePayload = ref<any>(null)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
null
)
const mixedChannelWarningRawMessage = ref('')
const mixedChannelWarningAction = ref<(() => Promise<void>) | null>(null)
const antigravityMixedChannelConfirmed = ref(false)
const showAdvancedOAuth = ref(false)
const showGeminiHelpDialog = ref(false)
@@ -2378,6 +2361,13 @@ const isOpenAIModelRestrictionDisabled = computed(() =>
form.platform === 'openai' && openaiPassthroughEnabled.value
)
const mixedChannelWarningMessageText = computed(() => {
if (mixedChannelWarningDetails.value) {
return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value)
}
return mixedChannelWarningRawMessage.value
})
const geminiQuotaDocs = {
codeAssist: 'https://developers.google.com/gemini-code-assist/resources/quotas',
aiStudio: 'https://ai.google.dev/pricing',
@@ -2544,8 +2534,8 @@ watch(
antigravityModelMappings.value = []
antigravityModelRestrictionMode.value = 'mapping'
}
// Reset Anthropic-specific settings when switching to other platforms
if (newPlatform !== 'anthropic') {
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
interceptWarmupRequests.value = false
}
if (newPlatform === 'sora') {
@@ -2794,6 +2784,105 @@ const splitTempUnschedKeywords = (value: string) => {
.filter((item) => item.length > 0)
}
const needsMixedChannelCheck = (platform: AccountPlatform) => platform === 'antigravity' || platform === 'anthropic'
const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => {
const details = resp?.details
if (!details) {
return null
}
return {
groupName: details.group_name || 'Unknown',
currentPlatform: details.current_platform || 'Unknown',
otherPlatform: details.other_platform || 'Unknown'
}
}
const clearMixedChannelDialog = () => {
showMixedChannelWarning.value = false
mixedChannelWarningDetails.value = null
mixedChannelWarningRawMessage.value = ''
mixedChannelWarningAction.value = null
}
const openMixedChannelDialog = (opts: {
response?: CheckMixedChannelResponse
message?: string
onConfirm: () => Promise<void>
}) => {
mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response)
mixedChannelWarningRawMessage.value =
opts.message || opts.response?.message || t('admin.accounts.failedToCreate')
mixedChannelWarningAction.value = opts.onConfirm
showMixedChannelWarning.value = true
}
const withAntigravityConfirmFlag = (payload: CreateAccountRequest): CreateAccountRequest => {
if (needsMixedChannelCheck(payload.platform) && antigravityMixedChannelConfirmed.value) {
return {
...payload,
confirm_mixed_channel_risk: true
}
}
const cloned = { ...payload }
delete cloned.confirm_mixed_channel_risk
return cloned
}
const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise<void>): Promise<boolean> => {
if (!needsMixedChannelCheck(form.platform)) {
return true
}
if (antigravityMixedChannelConfirmed.value) {
return true
}
try {
const result = await adminAPI.accounts.checkMixedChannelRisk({
platform: form.platform,
group_ids: form.group_ids
})
if (!result.has_risk) {
return true
}
openMixedChannelDialog({
response: result,
onConfirm: async () => {
antigravityMixedChannelConfirmed.value = true
await onConfirm()
}
})
return false
} catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate'))
return false
}
}
const submitCreateAccount = async (payload: CreateAccountRequest) => {
submitting.value = true
try {
await adminAPI.accounts.create(withAntigravityConfirmFlag(payload))
appStore.showSuccess(t('admin.accounts.accountCreated'))
emit('created')
handleClose()
} catch (error: any) {
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck(form.platform)) {
openMixedChannelDialog({
message: error.response?.data?.message,
onConfirm: async () => {
antigravityMixedChannelConfirmed.value = true
await submitCreateAccount(payload)
}
})
return
}
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate'))
} finally {
submitting.value = false
}
}
// Methods
const resetForm = () => {
step.value = 1
@@ -2855,9 +2944,13 @@ const resetForm = () => {
geminiOAuth.resetState()
antigravityOAuth.resetState()
oauthFlowRef.value?.reset()
antigravityMixedChannelConfirmed.value = false
clearMixedChannelDialog()
}
const handleClose = () => {
antigravityMixedChannelConfirmed.value = false
clearMixedChannelDialog()
emit('close')
}
@@ -2916,56 +3009,34 @@ const buildSoraExtra = (
}
// Helper function to create account with mixed channel warning handling
const doCreateAccount = async (payload: any) => {
const doCreateAccount = async (payload: CreateAccountRequest) => {
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
await submitCreateAccount(payload)
})
if (!canContinue) {
return
}
await submitCreateAccount(payload)
}
// Handle mixed channel warning confirmation
const handleMixedChannelConfirm = async () => {
const action = mixedChannelWarningAction.value
if (!action) {
clearMixedChannelDialog()
return
}
clearMixedChannelDialog()
submitting.value = true
try {
await adminAPI.accounts.create(payload)
appStore.showSuccess(t('admin.accounts.accountCreated'))
emit('created')
handleClose()
} catch (error: any) {
// Handle 409 mixed_channel_warning - show confirmation dialog
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') {
const details = error.response.data.details || {}
mixedChannelWarningDetails.value = {
groupName: details.group_name || 'Unknown',
currentPlatform: details.current_platform || 'Unknown',
otherPlatform: details.other_platform || 'Unknown'
}
pendingCreatePayload.value = payload
showMixedChannelWarning.value = true
} else {
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
}
await action()
} finally {
submitting.value = false
}
}
// Handle mixed channel warning confirmation
const handleMixedChannelConfirm = async () => {
showMixedChannelWarning.value = false
if (pendingCreatePayload.value) {
pendingCreatePayload.value.confirm_mixed_channel_risk = true
submitting.value = true
try {
await adminAPI.accounts.create(pendingCreatePayload.value)
appStore.showSuccess(t('admin.accounts.accountCreated'))
emit('created')
handleClose()
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
} finally {
submitting.value = false
pendingCreatePayload.value = null
}
}
}
const handleMixedChannelCancel = () => {
showMixedChannelWarning.value = false
pendingCreatePayload.value = null
mixedChannelWarningDetails.value = null
clearMixedChannelDialog()
}
const handleSubmit = async () => {
@@ -2975,6 +3046,12 @@ const handleSubmit = async () => {
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
return
}
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
step.value = 2
})
if (!canContinue) {
return
}
step.value = 2
return
}
@@ -3010,15 +3087,10 @@ const handleSubmit = async () => {
credentials.model_mapping = antigravityModelMapping
}
submitting.value = true
try {
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
} finally {
submitting.value = false
}
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
return
}
@@ -3059,10 +3131,7 @@ const handleSubmit = async () => {
credentials.custom_error_codes = [...selectedErrorCodes.value]
}
// Add intercept warmup requests setting
if (interceptWarmupRequests.value) {
credentials.intercept_warmup_requests = true
}
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
if (!applyTempUnschedConfig(credentials)) {
return
}
@@ -3132,7 +3201,7 @@ const createAccountAndFinish = async (
if (!applyTempUnschedConfig(credentials)) {
return
}
await adminAPI.accounts.create({
await doCreateAccount({
name: form.name,
notes: form.notes,
platform,
@@ -3147,9 +3216,6 @@ const createAccountAndFinish = async (
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
})
appStore.showSuccess(t('admin.accounts.accountCreated'))
emit('created')
handleClose()
}
// OpenAI OAuth 授权码兑换
@@ -3497,7 +3563,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
// Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials
await adminAPI.accounts.create({
const createPayload = withAntigravityConfirmFlag({
name: accountName,
notes: form.notes,
platform: 'antigravity',
@@ -3512,6 +3578,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
})
await adminAPI.accounts.create(createPayload)
successCount++
} catch (error: any) {
failedCount++
@@ -3606,6 +3673,7 @@ const handleAntigravityExchange = async (authCode: string) => {
if (!tokenInfo) return
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
// Antigravity 只使用映射模式
const antigravityModelMapping = buildModelMappingObject(
'mapping',
@@ -3677,10 +3745,8 @@ const handleAnthropicExchange = async (authCode: string) => {
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
}
const credentials = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
}
const credentials: Record<string, unknown> = { ...tokenInfo }
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra)
} catch (error: any) {
oauth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
@@ -3779,11 +3845,8 @@ const handleCookieAuth = async (sessionKey: string) => {
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
// Merge interceptWarmupRequests into credentials
const credentials: Record<string, unknown> = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
}
const credentials: Record<string, unknown> = { ...tokenInfo }
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
if (tempUnschedEnabled.value) {
credentials.temp_unschedulable_enabled = true
credentials.temp_unschedulable_rules = tempUnschedPayload

View File

@@ -65,8 +65,8 @@
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
</div>
<!-- Model Restriction Section (不适用于 Gemini Antigravity) -->
<div v-if="account.platform !== 'gemini' && account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<!-- Model Restriction Section (不适用于 Antigravity) -->
<div v-if="account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<div
@@ -349,34 +349,6 @@
</div>
</div>
<!-- Gemini 模型说明 -->
<div v-if="account.platform === 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="rounded-lg bg-blue-50 p-4 dark:bg-blue-900/20">
<div class="flex items-start gap-3">
<svg
class="h-5 w-5 flex-shrink-0 text-blue-600 dark:text-blue-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
<div>
<p class="text-sm font-medium text-blue-800 dark:text-blue-300">
{{ t('admin.accounts.gemini.modelPassthrough') }}
</p>
<p class="mt-1 text-xs text-blue-700 dark:text-blue-400">
{{ t('admin.accounts.gemini.modelPassthroughDesc') }}
</p>
</div>
</div>
</div>
</div>
</div>
<!-- Upstream fields (only for upstream type) -->
@@ -641,9 +613,9 @@
</div>
</div>
<!-- Intercept Warmup Requests (Anthropic only) -->
<!-- Intercept Warmup Requests (Anthropic/Antigravity) -->
<div
v-if="account?.platform === 'anthropic'"
v-if="account?.platform === 'anthropic' || account?.platform === 'antigravity'"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
>
<div class="flex items-center justify-between">
@@ -1139,7 +1111,7 @@
<ConfirmDialog
:show="showMixedChannelWarning"
:title="t('admin.accounts.mixedChannelWarningTitle')"
:message="mixedChannelWarningDetails ? t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails) : ''"
:message="mixedChannelWarningMessageText"
:confirm-text="t('common.confirm')"
:cancel-text="t('common.cancel')"
:danger="true"
@@ -1154,7 +1126,7 @@ import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { useAuthStore } from '@/stores/auth'
import { adminAPI } from '@/api/admin'
import type { Account, Proxy, AdminGroup } from '@/types'
import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
@@ -1162,6 +1134,7 @@ import Icon from '@/components/icons/Icon.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import {
@@ -1233,10 +1206,13 @@ const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-mod
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
// Mixed channel warning dialog state
const showMixedChannelWarning = ref(false)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null)
const pendingUpdatePayload = ref<Record<string, unknown> | null>(null)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
null
)
const mixedChannelWarningRawMessage = ref('')
const mixedChannelWarningAction = ref<(() => Promise<void>) | null>(null)
const antigravityMixedChannelConfirmed = ref(false)
// Quota control state (Anthropic OAuth/SetupToken only)
const windowCostEnabled = ref(false)
@@ -1297,6 +1273,13 @@ const defaultBaseUrl = computed(() => {
return 'https://api.anthropic.com'
})
const mixedChannelWarningMessageText = computed(() => {
if (mixedChannelWarningDetails.value) {
return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value)
}
return mixedChannelWarningRawMessage.value
})
const form = reactive({
name: '',
notes: '',
@@ -1326,6 +1309,11 @@ watch(
() => props.account,
(newAccount) => {
if (newAccount) {
antigravityMixedChannelConfirmed.value = false
showMixedChannelWarning.value = false
mixedChannelWarningDetails.value = null
mixedChannelWarningRawMessage.value = ''
mixedChannelWarningAction.value = null
form.name = newAccount.name
form.notes = newAccount.notes || ''
form.proxy_id = newAccount.proxy_id
@@ -1725,18 +1713,123 @@ function toPositiveNumber(value: unknown) {
return Math.trunc(num)
}
const needsMixedChannelCheck = () => props.account?.platform === 'antigravity' || props.account?.platform === 'anthropic'
const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => {
const details = resp?.details
if (!details) {
return null
}
return {
groupName: details.group_name || 'Unknown',
currentPlatform: details.current_platform || 'Unknown',
otherPlatform: details.other_platform || 'Unknown'
}
}
const clearMixedChannelDialog = () => {
showMixedChannelWarning.value = false
mixedChannelWarningDetails.value = null
mixedChannelWarningRawMessage.value = ''
mixedChannelWarningAction.value = null
}
const openMixedChannelDialog = (opts: {
response?: CheckMixedChannelResponse
message?: string
onConfirm: () => Promise<void>
}) => {
mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response)
mixedChannelWarningRawMessage.value =
opts.message || opts.response?.message || t('admin.accounts.failedToUpdate')
mixedChannelWarningAction.value = opts.onConfirm
showMixedChannelWarning.value = true
}
const withAntigravityConfirmFlag = (payload: Record<string, unknown>) => {
if (needsMixedChannelCheck() && antigravityMixedChannelConfirmed.value) {
return {
...payload,
confirm_mixed_channel_risk: true
}
}
const cloned = { ...payload }
delete cloned.confirm_mixed_channel_risk
return cloned
}
const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise<void>): Promise<boolean> => {
if (!needsMixedChannelCheck()) {
return true
}
if (antigravityMixedChannelConfirmed.value) {
return true
}
if (!props.account) {
return false
}
try {
const result = await adminAPI.accounts.checkMixedChannelRisk({
platform: props.account.platform,
group_ids: form.group_ids,
account_id: props.account.id
})
if (!result.has_risk) {
return true
}
openMixedChannelDialog({
response: result,
onConfirm: async () => {
antigravityMixedChannelConfirmed.value = true
await onConfirm()
}
})
return false
} catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
return false
}
}
const formatDateTimeLocal = formatDateTimeLocalInput
const parseDateTimeLocal = parseDateTimeLocalInput
// Methods
const handleClose = () => {
antigravityMixedChannelConfirmed.value = false
clearMixedChannelDialog()
emit('close')
}
const submitUpdateAccount = async (accountID: number, updatePayload: Record<string, unknown>) => {
submitting.value = true
try {
const updatedAccount = await adminAPI.accounts.update(accountID, withAntigravityConfirmFlag(updatePayload))
appStore.showSuccess(t('admin.accounts.accountUpdated'))
emit('updated', updatedAccount)
handleClose()
} catch (error: any) {
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck()) {
openMixedChannelDialog({
message: error.response?.data?.message,
onConfirm: async () => {
antigravityMixedChannelConfirmed.value = true
await submitUpdateAccount(accountID, updatePayload)
}
})
return
}
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
} finally {
submitting.value = false
}
}
const handleSubmit = async () => {
if (!props.account) return
const accountID = props.account.id
submitting.value = true
const updatePayload: Record<string, unknown> = { ...form }
try {
// 后端期望 proxy_id: 0 表示清除代理,而不是 null
@@ -1768,7 +1861,6 @@ const handleSubmit = async () => {
newCredentials.api_key = currentCredentials.api_key
} else {
appStore.showError(t('admin.accounts.apiKeyIsRequired'))
submitting.value = false
return
}
@@ -1789,11 +1881,8 @@ const handleSubmit = async () => {
}
// Add intercept warmup requests setting
if (interceptWarmupRequests.value) {
newCredentials.intercept_warmup_requests = true
}
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
submitting.value = false
return
}
@@ -1808,8 +1897,10 @@ const handleSubmit = async () => {
newCredentials.api_key = editApiKey.value.trim()
}
// Add intercept warmup requests setting
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
submitting.value = false
return
}
@@ -1819,13 +1910,8 @@ const handleSubmit = async () => {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
if (interceptWarmupRequests.value) {
newCredentials.intercept_warmup_requests = true
} else {
delete newCredentials.intercept_warmup_requests
}
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
submitting.value = false
return
}
@@ -1955,52 +2041,36 @@ const handleSubmit = async () => {
updatePayload.extra = newExtra
}
const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload)
appStore.showSuccess(t('admin.accounts.accountUpdated'))
emit('updated', updatedAccount)
handleClose()
} catch (error: any) {
// Handle 409 mixed_channel_warning - show confirmation dialog
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') {
const details = error.response.data.details || {}
mixedChannelWarningDetails.value = {
groupName: details.group_name || 'Unknown',
currentPlatform: details.current_platform || 'Unknown',
otherPlatform: details.other_platform || 'Unknown'
}
pendingUpdatePayload.value = updatePayload
showMixedChannelWarning.value = true
} else {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
await submitUpdateAccount(accountID, updatePayload)
})
if (!canContinue) {
return
}
} finally {
submitting.value = false
await submitUpdateAccount(accountID, updatePayload)
} catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
}
}
// Handle mixed channel warning confirmation
const handleMixedChannelConfirm = async () => {
showMixedChannelWarning.value = false
if (pendingUpdatePayload.value && props.account) {
pendingUpdatePayload.value.confirm_mixed_channel_risk = true
submitting.value = true
try {
const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
appStore.showSuccess(t('admin.accounts.accountUpdated'))
emit('updated', updatedAccount)
handleClose()
} catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
} finally {
submitting.value = false
pendingUpdatePayload.value = null
}
const action = mixedChannelWarningAction.value
if (!action) {
clearMixedChannelDialog()
return
}
clearMixedChannelDialog()
submitting.value = true
try {
await action()
} finally {
submitting.value = false
}
}
const handleMixedChannelCancel = () => {
showMixedChannelWarning.value = false
pendingUpdatePayload.value = null
mixedChannelWarningDetails.value = null
clearMixedChannelDialog()
}
</script>

View File

@@ -0,0 +1,46 @@
import { describe, it, expect } from 'vitest'
import { applyInterceptWarmup } from '../credentialsBuilder'
describe('applyInterceptWarmup', () => {
it('create + enabled=true: should set intercept_warmup_requests to true', () => {
const creds: Record<string, unknown> = { access_token: 'tok' }
applyInterceptWarmup(creds, true, 'create')
expect(creds.intercept_warmup_requests).toBe(true)
})
it('create + enabled=false: should not add the field', () => {
const creds: Record<string, unknown> = { access_token: 'tok' }
applyInterceptWarmup(creds, false, 'create')
expect('intercept_warmup_requests' in creds).toBe(false)
})
it('edit + enabled=true: should set intercept_warmup_requests to true', () => {
const creds: Record<string, unknown> = { api_key: 'sk' }
applyInterceptWarmup(creds, true, 'edit')
expect(creds.intercept_warmup_requests).toBe(true)
})
it('edit + enabled=false + field exists: should delete the field', () => {
const creds: Record<string, unknown> = { api_key: 'sk', intercept_warmup_requests: true }
applyInterceptWarmup(creds, false, 'edit')
expect('intercept_warmup_requests' in creds).toBe(false)
})
it('edit + enabled=false + field absent: should not throw', () => {
const creds: Record<string, unknown> = { api_key: 'sk' }
applyInterceptWarmup(creds, false, 'edit')
expect('intercept_warmup_requests' in creds).toBe(false)
})
it('should not affect other fields', () => {
const creds: Record<string, unknown> = {
api_key: 'sk',
base_url: 'url',
intercept_warmup_requests: true
}
applyInterceptWarmup(creds, false, 'edit')
expect(creds.api_key).toBe('sk')
expect(creds.base_url).toBe('url')
expect('intercept_warmup_requests' in creds).toBe(false)
})
})

View File

@@ -0,0 +1,11 @@
export function applyInterceptWarmup(
credentials: Record<string, unknown>,
enabled: boolean,
mode: 'create' | 'edit'
): void {
if (enabled) {
credentials.intercept_warmup_requests = true
} else if (mode === 'edit') {
delete credentials.intercept_warmup_requests
}
}

View File

@@ -76,6 +76,7 @@ const antigravityModels = [
// Claude 4.5+ 系列
'claude-opus-4-6',
'claude-opus-4-5-thinking',
'claude-sonnet-4-6',
'claude-sonnet-4-5',
'claude-sonnet-4-5-thinking',
// Gemini 2.5 系列
@@ -88,6 +89,9 @@ const antigravityModels = [
'gemini-3-pro-high',
'gemini-3-pro-low',
'gemini-3-pro-image',
// Gemini 3.1 系列
'gemini-3.1-pro-high',
'gemini-3.1-pro-low',
// 其他
'gpt-oss-120b-medium',
'tab_flash_lite_preview'
@@ -301,6 +305,7 @@ const antigravityPresetMappings = [
{ label: '3-Flash透传', from: 'gemini-3-flash', to: 'gemini-3-flash', color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400' },
{ label: '2.5-Flash-Lite透传', from: 'gemini-2.5-flash-lite', to: 'gemini-2.5-flash-lite', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
// 精确映射
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
]

View File

@@ -2047,7 +2047,7 @@ export default {
gemini3Pro: 'G3P',
gemini3Flash: 'G3F',
gemini3Image: 'G3I',
claude45: 'C4.5'
claude: 'Claude'
},
tier: {
free: 'Free',

View File

@@ -1583,7 +1583,7 @@ export default {
gemini3Pro: 'G3P',
gemini3Flash: 'G3F',
gemini3Image: 'G3I',
claude45: 'C4.5'
claude: 'Claude'
},
tier: {
free: 'Free',

View File

@@ -581,6 +581,7 @@ export interface GeminiCredentials {
token_type?: string
scope?: string
expires_at?: string
model_mapping?: Record<string, string>
}
export interface TempUnschedulableRule {
@@ -766,6 +767,26 @@ export interface UpdateAccountRequest {
confirm_mixed_channel_risk?: boolean
}
export interface CheckMixedChannelRequest {
platform: AccountPlatform
group_ids: number[]
account_id?: number
}
export interface MixedChannelWarningDetails {
group_id: number
group_name: string
current_platform: string
other_platform: string
}
export interface CheckMixedChannelResponse {
has_risk: boolean
error?: string
message?: string
details?: MixedChannelWarningDetails
}
export interface CreateProxyRequest {
name: string
protocol: ProxyProtocol

View File

@@ -1,18 +1,13 @@
import { defineConfig } from 'vitest/config'
import vue from '@vitejs/plugin-vue'
import { resolve } from 'path'
export default defineConfig({
plugins: [vue()],
resolve: {
alias: {
'@': resolve(__dirname, 'src'),
'vue-i18n': 'vue-i18n/dist/vue-i18n.runtime.esm-bundler.js'
}
},
define: {
__INTLIFY_JIT_COMPILATION__: true
},
test: {
globals: true,
environment: 'jsdom',
@@ -37,8 +32,6 @@ export default defineConfig({
lines: 80
}
}
},
setupFiles: ['./src/__tests__/setup.ts'],
testTimeout: 10000
}
}
})