From fdf9f68298eefc16e682a3efd8d4ac59ae156239 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 16:44:18 +0800 Subject: [PATCH 01/24] fix: update Claude usage window to support 4.6 models The usage progress bar only matched claude-sonnet-4-5 and claude-opus-4-5-thinking. After upgrading to 4.6, the backend returns claude-sonnet-4-6/claude-opus-4-6-thinking which didn't match, causing the Claude usage bar to not display. - Add claude-sonnet-4-6 and claude-opus-4-6-thinking to the match list - Rename label from "C4.5" to "Claude" for future-proofing --- .../components/account/AccountUsageCell.vue | 19 +++++++++++-------- frontend/src/i18n/locales/en.ts | 2 +- frontend/src/i18n/locales/zh.ts | 2 +- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index cada94c6..b47b4115 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -172,12 +172,12 @@ color="purple" /> - + @@ -400,9 +400,12 @@ const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI( // Gemini 3 Image from API const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-pro-image'])) -// Claude 4.5 from API -const antigravityClaude45UsageFromAPI = computed(() => - getAntigravityUsageFromAPI(['claude-sonnet-4-5', 'claude-opus-4-5-thinking']) +// Claude from API (all Claude model variants) +const antigravityClaudeUsageFromAPI = computed(() => + getAntigravityUsageFromAPI([ + 'claude-sonnet-4-5', 'claude-opus-4-5-thinking', + 'claude-sonnet-4-6', 'claude-opus-4-6-thinking', + ]) ) // Antigravity 账户类型(从 load_code_assist 响应中提取) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 3c415989..cdd9ad19 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2047,7 +2047,7 @@ export default { gemini3Pro: 'G3P', gemini3Flash: 'G3F', gemini3Image: 'G3I', - claude45: 'C4.5' + claude: 'Claude' }, tier: { free: 'Free', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 770f9ca9..8ef50267 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1583,7 +1583,7 @@ export default { gemini3Pro: 'G3P', gemini3Flash: 'G3F', gemini3Image: 'G3I', - claude45: 'C4.5' + claude: 'Claude' }, tier: { free: 'Free', From 0dacdf480b485db3dfc7faed4d7bed5de0c9f84c Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 16:45:08 +0800 Subject: [PATCH 02/24] fix: distinguish client disconnection from upstream retry failure Before this change, when a client disconnected mid-request, the error message was "Upstream request failed after retries", which is misleading and pollutes error logs. Now we check context.Err() to return a more accurate "Client disconnected" message for both Claude and Gemini forward paths. --- backend/internal/service/antigravity_gateway_service.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index cf87b282..26b14e68 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1370,6 +1370,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ForceCacheBilling: switchErr.IsStickySession, } } + // 区分客户端取消和真正的上游失败,返回更准确的错误消息 + if c.Request.Context().Err() != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response") + } return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } resp := result.resp @@ -2042,6 +2046,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ForceCacheBilling: switchErr.IsStickySession, } } + // 区分客户端取消和真正的上游失败,返回更准确的错误消息 + if c.Request.Context().Err() != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response") + } return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } resp := result.resp From 59898c16c697193c1ac8d8b051c7967f19e5e03a Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 16:48:16 +0800 Subject: [PATCH 03/24] fix: fix intercept_warmup_requests config not being saved Extract applyInterceptWarmup utility to unify all credential building call sites: - Fix upstream account creation missing intercept_warmup_requests write - Fix apikey edit mode missing else-branch to clear the setting - Add backend unit test for IsInterceptWarmupEnabled - Add frontend unit test for credentialsBuilder --- .../service/account_intercept_warmup_test.go | 66 +++++++++++++++++++ .../components/account/CreateAccountModal.vue | 22 +++---- .../components/account/EditAccountModal.vue | 14 ++-- .../__tests__/credentialsBuilder.spec.ts | 46 +++++++++++++ .../components/account/credentialsBuilder.ts | 11 ++++ 5 files changed, 138 insertions(+), 21 deletions(-) create mode 100644 backend/internal/service/account_intercept_warmup_test.go create mode 100644 frontend/src/components/account/__tests__/credentialsBuilder.spec.ts create mode 100644 frontend/src/components/account/credentialsBuilder.ts diff --git a/backend/internal/service/account_intercept_warmup_test.go b/backend/internal/service/account_intercept_warmup_test.go new file mode 100644 index 00000000..f117fd8d --- /dev/null +++ b/backend/internal/service/account_intercept_warmup_test.go @@ -0,0 +1,66 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsInterceptWarmupEnabled(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + expected bool + }{ + { + name: "nil credentials", + credentials: nil, + expected: false, + }, + { + name: "empty map", + credentials: map[string]any{}, + expected: false, + }, + { + name: "field not present", + credentials: map[string]any{"access_token": "tok"}, + expected: false, + }, + { + name: "field is true", + credentials: map[string]any{"intercept_warmup_requests": true}, + expected: true, + }, + { + name: "field is false", + credentials: map[string]any{"intercept_warmup_requests": false}, + expected: false, + }, + { + name: "field is string true", + credentials: map[string]any{"intercept_warmup_requests": "true"}, + expected: false, + }, + { + name: "field is int 1", + credentials: map[string]any{"intercept_warmup_requests": 1}, + expected: false, + }, + { + name: "field is nil", + credentials: map[string]any{"intercept_warmup_requests": nil}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Credentials: tt.credentials} + result := a.IsInterceptWarmupEnabled() + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 30da0767..64253447 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2196,6 +2196,7 @@ import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' +import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue' @@ -3010,6 +3011,8 @@ const handleSubmit = async () => { credentials.model_mapping = antigravityModelMapping } + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') + submitting.value = true try { const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined @@ -3059,10 +3062,7 @@ const handleSubmit = async () => { credentials.custom_error_codes = [...selectedErrorCodes.value] } - // Add intercept warmup requests setting - if (interceptWarmupRequests.value) { - credentials.intercept_warmup_requests = true - } + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') if (!applyTempUnschedConfig(credentials)) { return } @@ -3606,6 +3606,7 @@ const handleAntigravityExchange = async (authCode: string) => { if (!tokenInfo) return const credentials = antigravityOAuth.buildCredentials(tokenInfo) + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') // Antigravity 只使用映射模式 const antigravityModelMapping = buildModelMappingObject( 'mapping', @@ -3677,10 +3678,8 @@ const handleAnthropicExchange = async (authCode: string) => { extra.cache_ttl_override_target = cacheTTLOverrideTarget.value } - const credentials = { - ...tokenInfo, - ...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {}) - } + const credentials: Record = { ...tokenInfo } + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra) } catch (error: any) { oauth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') @@ -3779,11 +3778,8 @@ const handleCookieAuth = async (sessionKey: string) => { const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name - // Merge interceptWarmupRequests into credentials - const credentials: Record = { - ...tokenInfo, - ...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {}) - } + const credentials: Record = { ...tokenInfo } + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') if (tempUnschedEnabled.value) { credentials.temp_unschedulable_enabled = true credentials.temp_unschedulable_rules = tempUnschedPayload diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index c6643717..3d75df7d 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1162,6 +1162,7 @@ import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' +import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { @@ -1789,9 +1790,7 @@ const handleSubmit = async () => { } // Add intercept warmup requests setting - if (interceptWarmupRequests.value) { - newCredentials.intercept_warmup_requests = true - } + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') if (!applyTempUnschedConfig(newCredentials)) { submitting.value = false return @@ -1808,6 +1807,9 @@ const handleSubmit = async () => { newCredentials.api_key = editApiKey.value.trim() } + // Add intercept warmup requests setting + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') + if (!applyTempUnschedConfig(newCredentials)) { submitting.value = false return @@ -1819,11 +1821,7 @@ const handleSubmit = async () => { const currentCredentials = (props.account.credentials as Record) || {} const newCredentials: Record = { ...currentCredentials } - if (interceptWarmupRequests.value) { - newCredentials.intercept_warmup_requests = true - } else { - delete newCredentials.intercept_warmup_requests - } + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') if (!applyTempUnschedConfig(newCredentials)) { submitting.value = false return diff --git a/frontend/src/components/account/__tests__/credentialsBuilder.spec.ts b/frontend/src/components/account/__tests__/credentialsBuilder.spec.ts new file mode 100644 index 00000000..be2a8d52 --- /dev/null +++ b/frontend/src/components/account/__tests__/credentialsBuilder.spec.ts @@ -0,0 +1,46 @@ +import { describe, it, expect } from 'vitest' +import { applyInterceptWarmup } from '../credentialsBuilder' + +describe('applyInterceptWarmup', () => { + it('create + enabled=true: should set intercept_warmup_requests to true', () => { + const creds: Record = { access_token: 'tok' } + applyInterceptWarmup(creds, true, 'create') + expect(creds.intercept_warmup_requests).toBe(true) + }) + + it('create + enabled=false: should not add the field', () => { + const creds: Record = { access_token: 'tok' } + applyInterceptWarmup(creds, false, 'create') + expect('intercept_warmup_requests' in creds).toBe(false) + }) + + it('edit + enabled=true: should set intercept_warmup_requests to true', () => { + const creds: Record = { api_key: 'sk' } + applyInterceptWarmup(creds, true, 'edit') + expect(creds.intercept_warmup_requests).toBe(true) + }) + + it('edit + enabled=false + field exists: should delete the field', () => { + const creds: Record = { api_key: 'sk', intercept_warmup_requests: true } + applyInterceptWarmup(creds, false, 'edit') + expect('intercept_warmup_requests' in creds).toBe(false) + }) + + it('edit + enabled=false + field absent: should not throw', () => { + const creds: Record = { api_key: 'sk' } + applyInterceptWarmup(creds, false, 'edit') + expect('intercept_warmup_requests' in creds).toBe(false) + }) + + it('should not affect other fields', () => { + const creds: Record = { + api_key: 'sk', + base_url: 'url', + intercept_warmup_requests: true + } + applyInterceptWarmup(creds, false, 'edit') + expect(creds.api_key).toBe('sk') + expect(creds.base_url).toBe('url') + expect('intercept_warmup_requests' in creds).toBe(false) + }) +}) diff --git a/frontend/src/components/account/credentialsBuilder.ts b/frontend/src/components/account/credentialsBuilder.ts new file mode 100644 index 00000000..b8008e8b --- /dev/null +++ b/frontend/src/components/account/credentialsBuilder.ts @@ -0,0 +1,11 @@ +export function applyInterceptWarmup( + credentials: Record, + enabled: boolean, + mode: 'create' | 'edit' +): void { + if (enabled) { + credentials.intercept_warmup_requests = true + } else if (mode === 'edit') { + delete credentials.intercept_warmup_requests + } +} From aaac1aaca96dce1600827b68f0a3b918ed0e59d8 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 17:11:14 +0800 Subject: [PATCH 04/24] feat: add mixed-channel precheck API for account-group binding Add a dedicated CheckMixedChannel endpoint that allows the frontend to pre-validate mixed channel risk before submitting create/update requests. This improves UX by showing warnings earlier in the flow instead of only after form submission. Backend changes: - Add CheckMixedChannelRequest struct and CheckMixedChannel handler - Register POST /check-mixed-channel route - Expose CheckMixedChannelRisk as public method on AdminService - Simplify Create/Update 409 responses (remove details/require_confirmation) - Add comprehensive handler tests and stub methods Frontend changes: - Add checkMixedChannelRisk API function and TypeScript types - Refactor CreateAccountModal to precheck before step transition and submission - Refactor EditAccountModal to precheck before update submission - Replace pendingPayload pattern with action-based dialog flow --- .../internal/handler/admin/account_handler.go | 69 ++++-- .../account_handler_mixed_channel_test.go | 147 +++++++++++++ .../handler/admin/admin_service_stub_test.go | 47 ++-- backend/internal/server/routes/admin.go | 1 + backend/internal/service/admin_service.go | 6 + frontend/src/api/admin/accounts.ts | 15 +- .../components/account/CreateAccountModal.vue | 206 +++++++++++++----- .../components/account/EditAccountModal.vue | 194 +++++++++++++---- frontend/src/types/index.ts | 20 ++ 9 files changed, 576 insertions(+), 129 deletions(-) create mode 100644 backend/internal/handler/admin/account_handler_mixed_channel_test.go diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index a2a8dd43..df82476c 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -139,6 +139,13 @@ type BulkUpdateAccountsRequest struct { ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 } +// CheckMixedChannelRequest represents check mixed channel risk request +type CheckMixedChannelRequest struct { + Platform string `json:"platform" binding:"required"` + GroupIDs []int64 `json:"group_ids"` + AccountID *int64 `json:"account_id"` +} + // AccountWithConcurrency extends Account with real-time concurrency info type AccountWithConcurrency struct { *dto.Account @@ -389,6 +396,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) { response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +// CheckMixedChannel handles checking mixed channel risk for account-group binding. +// POST /api/v1/admin/accounts/check-mixed-channel +func (h *AccountHandler) CheckMixedChannel(c *gin.Context) { + var req CheckMixedChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if len(req.GroupIDs) == 0 { + response.Success(c, gin.H{"has_risk": false}) + return + } + + accountID := int64(0) + if req.AccountID != nil { + accountID = *req.AccountID + } + + err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs) + if err != nil { + var mixedErr *service.MixedChannelError + if errors.As(err, &mixedErr) { + response.Success(c, gin.H{ + "has_risk": true, + "error": "mixed_channel_warning", + "message": mixedErr.Error(), + "details": gin.H{ + "group_id": mixedErr.GroupID, + "group_name": mixedErr.GroupName, + "current_platform": mixedErr.CurrentPlatform, + "other_platform": mixedErr.OtherPlatform, + }, + }) + return + } + + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"has_risk": false}) +} + // Create handles creating a new account // POST /api/v1/admin/accounts func (h *AccountHandler) Create(c *gin.Context) { @@ -431,17 +482,10 @@ func (h *AccountHandler) Create(c *gin.Context) { // 检查是否为混合渠道错误 var mixedErr *service.MixedChannelError if errors.As(err, &mixedErr) { - // 返回特殊错误码要求确认 + // 创建接口仅返回最小必要字段,详细信息由专门检查接口提供 c.JSON(409, gin.H{ "error": "mixed_channel_warning", "message": mixedErr.Error(), - "details": gin.H{ - "group_id": mixedErr.GroupID, - "group_name": mixedErr.GroupName, - "current_platform": mixedErr.CurrentPlatform, - "other_platform": mixedErr.OtherPlatform, - }, - "require_confirmation": true, }) return } @@ -501,17 +545,10 @@ func (h *AccountHandler) Update(c *gin.Context) { // 检查是否为混合渠道错误 var mixedErr *service.MixedChannelError if errors.As(err, &mixedErr) { - // 返回特殊错误码要求确认 + // 更新接口仅返回最小必要字段,详细信息由专门检查接口提供 c.JSON(409, gin.H{ "error": "mixed_channel_warning", "message": mixedErr.Error(), - "details": gin.H{ - "group_id": mixedErr.GroupID, - "group_name": mixedErr.GroupName, - "current_platform": mixedErr.CurrentPlatform, - "other_platform": mixedErr.OtherPlatform, - }, - "require_confirmation": true, }) return } diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go new file mode 100644 index 00000000..ad004844 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -0,0 +1,147 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel) + router.POST("/api/v1/admin/accounts", accountHandler.Create) + router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update) + return router +} + +func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "platform": "antigravity", + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, false, data["has_risk"]) + require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID) + require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform) + require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs) +} + +func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.checkMixedErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "platform": "antigravity", + "group_ids": []int64{27}, + "account_id": 99, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, data["has_risk"]) + require.Equal(t, "mixed_channel_warning", data["error"]) + details, ok := data["details"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(27), details["group_id"]) + require.Equal(t, "claude-max", details["group_name"]) + require.Equal(t, "Antigravity", details["current_platform"]) + require.Equal(t, "Anthropic", details["other_platform"]) + require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID) +} + +func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.createAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "name": "ag-oauth-1", + "platform": "antigravity", + "type": "oauth", + "credentials": map[string]any{"refresh_token": "rt"}, + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "mixed_channel_warning") + _, hasDetails := resp["details"] + _, hasRequireConfirmation := resp["require_confirmation"] + require.False(t, hasDetails) + require.False(t, hasRequireConfirmation) +} + +func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.updateAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "mixed_channel_warning") + _, hasDetails := resp["details"] + _, hasRequireConfirmation := resp["require_confirmation"] + require.False(t, hasDetails) + require.False(t, hasRequireConfirmation) +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 9f3dcf80..848122e4 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -10,19 +10,27 @@ import ( ) type stubAdminService struct { - users []service.User - apiKeys []service.APIKey - groups []service.Group - accounts []service.Account - proxies []service.Proxy - proxyCounts []service.ProxyWithAccountCount - redeems []service.RedeemCode - createdAccounts []*service.CreateAccountInput - createdProxies []*service.CreateProxyInput - updatedProxyIDs []int64 - updatedProxies []*service.UpdateProxyInput - testedProxyIDs []int64 - mu sync.Mutex + users []service.User + apiKeys []service.APIKey + groups []service.Group + accounts []service.Account + proxies []service.Proxy + proxyCounts []service.ProxyWithAccountCount + redeems []service.RedeemCode + createdAccounts []*service.CreateAccountInput + createdProxies []*service.CreateProxyInput + updatedProxyIDs []int64 + updatedProxies []*service.UpdateProxyInput + testedProxyIDs []int64 + createAccountErr error + updateAccountErr error + checkMixedErr error + lastMixedCheck struct { + accountID int64 + platform string + groupIDs []int64 + } + mu sync.Mutex } func newStubAdminService() *stubAdminService { @@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre s.mu.Lock() s.createdAccounts = append(s.createdAccounts, input) s.mu.Unlock() + if s.createAccountErr != nil { + return nil, s.createAccountErr + } account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive} return &account, nil } func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + if s.updateAccountErr != nil { + return nil, s.updateAccountErr + } account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive} return &account, nil } @@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil } +func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { + s.lastMixedCheck.accountID = currentAccountID + s.lastMixedCheck.platform = currentAccountPlatform + s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...) + return s.checkMixedErr +} + func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { search = strings.TrimSpace(strings.ToLower(search)) filtered := make([]service.Proxy, 0, len(s.proxies)) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 4b4d97c3..36efacc8 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -219,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("", h.Admin.Account.List) accounts.GET("/:id", h.Admin.Account.GetByID) accounts.POST("", h.Admin.Account.Create) + accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel) accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS) accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS) accounts.PUT("/:id", h.Admin.Account.Update) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 8614f24a..47339661 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -54,6 +54,7 @@ type AdminService interface { SetAccountError(ctx context.Context, id int64, errorMsg string) error SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) + CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error // Proxy management ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) @@ -2114,6 +2115,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc return nil } +// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. +func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { + return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) +} + func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { if s.proxyLatencyCache == nil || len(proxies) == 0 { return diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 89b11783..1b8ae9ad 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -15,7 +15,9 @@ import type { AccountUsageStatsResponse, TempUnschedulableStatus, AdminDataPayload, - AdminDataImportResult + AdminDataImportResult, + CheckMixedChannelRequest, + CheckMixedChannelResponse } from '@/types' /** @@ -133,6 +135,16 @@ export async function update(id: number, updates: UpdateAccountRequest): Promise return data } +/** + * Check mixed-channel risk for account-group binding. + */ +export async function checkMixedChannelRisk( + payload: CheckMixedChannelRequest +): Promise { + const { data } = await apiClient.post('/admin/accounts/check-mixed-channel', payload) + return data +} + /** * Delete account * @param id - Account ID @@ -535,6 +547,7 @@ export const accountsAPI = { getById, create, update, + checkMixedChannelRisk, delete: deleteAccount, toggleStatus, testAccount, diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 64253447..83b65159 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2157,7 +2157,7 @@ const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one') const geminiAIStudioOAuthEnabled = ref(false) -// Mixed channel warning dialog state const showMixedChannelWarning = ref(false) -const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null) -const pendingCreatePayload = ref(null) +const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>( + null +) +const mixedChannelWarningRawMessage = ref('') +const mixedChannelWarningAction = ref<(() => Promise) | null>(null) +const antigravityMixedChannelConfirmed = ref(false) const showAdvancedOAuth = ref(false) const showGeminiHelpDialog = ref(false) @@ -2379,6 +2389,13 @@ const isOpenAIModelRestrictionDisabled = computed(() => form.platform === 'openai' && openaiPassthroughEnabled.value ) +const mixedChannelWarningMessageText = computed(() => { + if (mixedChannelWarningDetails.value) { + return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value) + } + return mixedChannelWarningRawMessage.value +}) + const geminiQuotaDocs = { codeAssist: 'https://developers.google.com/gemini-code-assist/resources/quotas', aiStudio: 'https://ai.google.dev/pricing', @@ -2795,6 +2812,105 @@ const splitTempUnschedKeywords = (value: string) => { .filter((item) => item.length > 0) } +const needsMixedChannelCheck = (platform: AccountPlatform) => platform === 'antigravity' || platform === 'anthropic' + +const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => { + const details = resp?.details + if (!details) { + return null + } + return { + groupName: details.group_name || 'Unknown', + currentPlatform: details.current_platform || 'Unknown', + otherPlatform: details.other_platform || 'Unknown' + } +} + +const clearMixedChannelDialog = () => { + showMixedChannelWarning.value = false + mixedChannelWarningDetails.value = null + mixedChannelWarningRawMessage.value = '' + mixedChannelWarningAction.value = null +} + +const openMixedChannelDialog = (opts: { + response?: CheckMixedChannelResponse + message?: string + onConfirm: () => Promise +}) => { + mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response) + mixedChannelWarningRawMessage.value = + opts.message || opts.response?.message || t('admin.accounts.failedToCreate') + mixedChannelWarningAction.value = opts.onConfirm + showMixedChannelWarning.value = true +} + +const withAntigravityConfirmFlag = (payload: CreateAccountRequest): CreateAccountRequest => { + if (needsMixedChannelCheck(payload.platform) && antigravityMixedChannelConfirmed.value) { + return { + ...payload, + confirm_mixed_channel_risk: true + } + } + const cloned = { ...payload } + delete cloned.confirm_mixed_channel_risk + return cloned +} + +const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise): Promise => { + if (!needsMixedChannelCheck(form.platform)) { + return true + } + if (antigravityMixedChannelConfirmed.value) { + return true + } + + try { + const result = await adminAPI.accounts.checkMixedChannelRisk({ + platform: form.platform, + group_ids: form.group_ids + }) + if (!result.has_risk) { + return true + } + openMixedChannelDialog({ + response: result, + onConfirm: async () => { + antigravityMixedChannelConfirmed.value = true + await onConfirm() + } + }) + return false + } catch (error: any) { + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate')) + return false + } +} + +const submitCreateAccount = async (payload: CreateAccountRequest) => { + submitting.value = true + try { + await adminAPI.accounts.create(withAntigravityConfirmFlag(payload)) + appStore.showSuccess(t('admin.accounts.accountCreated')) + emit('created') + handleClose() + } catch (error: any) { + if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck(form.platform)) { + openMixedChannelDialog({ + message: error.response?.data?.message, + onConfirm: async () => { + antigravityMixedChannelConfirmed.value = true + await submitCreateAccount(payload) + } + }) + return + } + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate')) + } finally { + submitting.value = false + } +} + // Methods const resetForm = () => { step.value = 1 @@ -2856,9 +2972,13 @@ const resetForm = () => { geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() + antigravityMixedChannelConfirmed.value = false + clearMixedChannelDialog() } const handleClose = () => { + antigravityMixedChannelConfirmed.value = false + clearMixedChannelDialog() emit('close') } @@ -2917,56 +3037,34 @@ const buildSoraExtra = ( } // Helper function to create account with mixed channel warning handling -const doCreateAccount = async (payload: any) => { +const doCreateAccount = async (payload: CreateAccountRequest) => { + const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => { + await submitCreateAccount(payload) + }) + if (!canContinue) { + return + } + await submitCreateAccount(payload) +} + +// Handle mixed channel warning confirmation +const handleMixedChannelConfirm = async () => { + const action = mixedChannelWarningAction.value + if (!action) { + clearMixedChannelDialog() + return + } + clearMixedChannelDialog() submitting.value = true try { - await adminAPI.accounts.create(payload) - appStore.showSuccess(t('admin.accounts.accountCreated')) - emit('created') - handleClose() - } catch (error: any) { - // Handle 409 mixed_channel_warning - show confirmation dialog - if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') { - const details = error.response.data.details || {} - mixedChannelWarningDetails.value = { - groupName: details.group_name || 'Unknown', - currentPlatform: details.current_platform || 'Unknown', - otherPlatform: details.other_platform || 'Unknown' - } - pendingCreatePayload.value = payload - showMixedChannelWarning.value = true - } else { - appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) - } + await action() } finally { submitting.value = false } } -// Handle mixed channel warning confirmation -const handleMixedChannelConfirm = async () => { - showMixedChannelWarning.value = false - if (pendingCreatePayload.value) { - pendingCreatePayload.value.confirm_mixed_channel_risk = true - submitting.value = true - try { - await adminAPI.accounts.create(pendingCreatePayload.value) - appStore.showSuccess(t('admin.accounts.accountCreated')) - emit('created') - handleClose() - } catch (error: any) { - appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) - } finally { - submitting.value = false - pendingCreatePayload.value = null - } - } -} - const handleMixedChannelCancel = () => { - showMixedChannelWarning.value = false - pendingCreatePayload.value = null - mixedChannelWarningDetails.value = null + clearMixedChannelDialog() } const handleSubmit = async () => { @@ -2976,6 +3074,12 @@ const handleSubmit = async () => { appStore.showError(t('admin.accounts.pleaseEnterAccountName')) return } + const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => { + step.value = 2 + }) + if (!canContinue) { + return + } step.value = 2 return } @@ -3132,7 +3236,7 @@ const createAccountAndFinish = async ( if (!applyTempUnschedConfig(credentials)) { return } - await adminAPI.accounts.create({ + await doCreateAccount({ name: form.name, notes: form.notes, platform, @@ -3147,9 +3251,6 @@ const createAccountAndFinish = async ( expires_at: form.expires_at, auto_pause_on_expired: autoPauseOnExpired.value }) - appStore.showSuccess(t('admin.accounts.accountCreated')) - emit('created') - handleClose() } // OpenAI OAuth 授权码兑换 @@ -3497,7 +3598,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => { const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name // Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials - await adminAPI.accounts.create({ + const createPayload = withAntigravityConfirmFlag({ name: accountName, notes: form.notes, platform: 'antigravity', @@ -3512,6 +3613,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => { expires_at: form.expires_at, auto_pause_on_expired: autoPauseOnExpired.value }) + await adminAPI.accounts.create(createPayload) successCount++ } catch (error: any) { failedCount++ diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 3d75df7d..c29aa54b 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1139,7 +1139,7 @@ ('edit-mod const getAntigravityModelMappingKey = createStableObjectKeyResolver('edit-antigravity-model-mapping') const getTempUnschedRuleKey = createStableObjectKeyResolver('edit-temp-unsched-rule') -// Mixed channel warning dialog state const showMixedChannelWarning = ref(false) -const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null) -const pendingUpdatePayload = ref | null>(null) +const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>( + null +) +const mixedChannelWarningRawMessage = ref('') +const mixedChannelWarningAction = ref<(() => Promise) | null>(null) +const antigravityMixedChannelConfirmed = ref(false) // Quota control state (Anthropic OAuth/SetupToken only) const windowCostEnabled = ref(false) @@ -1298,6 +1301,13 @@ const defaultBaseUrl = computed(() => { return 'https://api.anthropic.com' }) +const mixedChannelWarningMessageText = computed(() => { + if (mixedChannelWarningDetails.value) { + return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value) + } + return mixedChannelWarningRawMessage.value +}) + const form = reactive({ name: '', notes: '', @@ -1327,6 +1337,11 @@ watch( () => props.account, (newAccount) => { if (newAccount) { + antigravityMixedChannelConfirmed.value = false + showMixedChannelWarning.value = false + mixedChannelWarningDetails.value = null + mixedChannelWarningRawMessage.value = '' + mixedChannelWarningAction.value = null form.name = newAccount.name form.notes = newAccount.notes || '' form.proxy_id = newAccount.proxy_id @@ -1726,18 +1741,123 @@ function toPositiveNumber(value: unknown) { return Math.trunc(num) } +const needsMixedChannelCheck = () => props.account?.platform === 'antigravity' || props.account?.platform === 'anthropic' + +const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => { + const details = resp?.details + if (!details) { + return null + } + return { + groupName: details.group_name || 'Unknown', + currentPlatform: details.current_platform || 'Unknown', + otherPlatform: details.other_platform || 'Unknown' + } +} + +const clearMixedChannelDialog = () => { + showMixedChannelWarning.value = false + mixedChannelWarningDetails.value = null + mixedChannelWarningRawMessage.value = '' + mixedChannelWarningAction.value = null +} + +const openMixedChannelDialog = (opts: { + response?: CheckMixedChannelResponse + message?: string + onConfirm: () => Promise +}) => { + mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response) + mixedChannelWarningRawMessage.value = + opts.message || opts.response?.message || t('admin.accounts.failedToUpdate') + mixedChannelWarningAction.value = opts.onConfirm + showMixedChannelWarning.value = true +} + +const withAntigravityConfirmFlag = (payload: Record) => { + if (needsMixedChannelCheck() && antigravityMixedChannelConfirmed.value) { + return { + ...payload, + confirm_mixed_channel_risk: true + } + } + const cloned = { ...payload } + delete cloned.confirm_mixed_channel_risk + return cloned +} + +const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise): Promise => { + if (!needsMixedChannelCheck()) { + return true + } + if (antigravityMixedChannelConfirmed.value) { + return true + } + if (!props.account) { + return false + } + + try { + const result = await adminAPI.accounts.checkMixedChannelRisk({ + platform: props.account.platform, + group_ids: form.group_ids, + account_id: props.account.id + }) + if (!result.has_risk) { + return true + } + openMixedChannelDialog({ + response: result, + onConfirm: async () => { + antigravityMixedChannelConfirmed.value = true + await onConfirm() + } + }) + return false + } catch (error: any) { + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) + return false + } +} + const formatDateTimeLocal = formatDateTimeLocalInput const parseDateTimeLocal = parseDateTimeLocalInput // Methods const handleClose = () => { + antigravityMixedChannelConfirmed.value = false + clearMixedChannelDialog() emit('close') } +const submitUpdateAccount = async (accountID: number, updatePayload: Record) => { + submitting.value = true + try { + const updatedAccount = await adminAPI.accounts.update(accountID, withAntigravityConfirmFlag(updatePayload)) + appStore.showSuccess(t('admin.accounts.accountUpdated')) + emit('updated', updatedAccount) + handleClose() + } catch (error: any) { + if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck()) { + openMixedChannelDialog({ + message: error.response?.data?.message, + onConfirm: async () => { + antigravityMixedChannelConfirmed.value = true + await submitUpdateAccount(accountID, updatePayload) + } + }) + return + } + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) + } finally { + submitting.value = false + } +} + const handleSubmit = async () => { if (!props.account) return + const accountID = props.account.id - submitting.value = true const updatePayload: Record = { ...form } try { // 后端期望 proxy_id: 0 表示清除代理,而不是 null @@ -1769,7 +1889,6 @@ const handleSubmit = async () => { newCredentials.api_key = currentCredentials.api_key } else { appStore.showError(t('admin.accounts.apiKeyIsRequired')) - submitting.value = false return } @@ -1792,7 +1911,6 @@ const handleSubmit = async () => { // Add intercept warmup requests setting applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') if (!applyTempUnschedConfig(newCredentials)) { - submitting.value = false return } @@ -1811,7 +1929,6 @@ const handleSubmit = async () => { applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') if (!applyTempUnschedConfig(newCredentials)) { - submitting.value = false return } @@ -1823,7 +1940,6 @@ const handleSubmit = async () => { applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') if (!applyTempUnschedConfig(newCredentials)) { - submitting.value = false return } @@ -1953,52 +2069,36 @@ const handleSubmit = async () => { updatePayload.extra = newExtra } - const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload) - appStore.showSuccess(t('admin.accounts.accountUpdated')) - emit('updated', updatedAccount) - handleClose() - } catch (error: any) { - // Handle 409 mixed_channel_warning - show confirmation dialog - if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') { - const details = error.response.data.details || {} - mixedChannelWarningDetails.value = { - groupName: details.group_name || 'Unknown', - currentPlatform: details.current_platform || 'Unknown', - otherPlatform: details.other_platform || 'Unknown' - } - pendingUpdatePayload.value = updatePayload - showMixedChannelWarning.value = true - } else { - appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) + const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => { + await submitUpdateAccount(accountID, updatePayload) + }) + if (!canContinue) { + return } - } finally { - submitting.value = false + + await submitUpdateAccount(accountID, updatePayload) + } catch (error: any) { + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) } } // Handle mixed channel warning confirmation const handleMixedChannelConfirm = async () => { - showMixedChannelWarning.value = false - if (pendingUpdatePayload.value && props.account) { - pendingUpdatePayload.value.confirm_mixed_channel_risk = true - submitting.value = true - try { - const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value) - appStore.showSuccess(t('admin.accounts.accountUpdated')) - emit('updated', updatedAccount) - handleClose() - } catch (error: any) { - appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) - } finally { - submitting.value = false - pendingUpdatePayload.value = null - } + const action = mixedChannelWarningAction.value + if (!action) { + clearMixedChannelDialog() + return + } + clearMixedChannelDialog() + submitting.value = true + try { + await action() + } finally { + submitting.value = false } } const handleMixedChannelCancel = () => { - showMixedChannelWarning.value = false - pendingUpdatePayload.value = null - mixedChannelWarningDetails.value = null + clearMixedChannelDialog() } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 1284c176..70fe5a27 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -766,6 +766,26 @@ export interface UpdateAccountRequest { confirm_mixed_channel_risk?: boolean } +export interface CheckMixedChannelRequest { + platform: AccountPlatform + group_ids: number[] + account_id?: number +} + +export interface MixedChannelWarningDetails { + group_id: number + group_name: string + current_platform: string + other_platform: string +} + +export interface CheckMixedChannelResponse { + has_risk: boolean + error?: string + message?: string + details?: MixedChannelWarningDetails +} + export interface CreateProxyRequest { name: string protocol: ProxyProtocol From 09166a52f89cb873b75a9061285b1864d57ee015 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 18:08:04 +0800 Subject: [PATCH 05/24] refactor: extract failover error handling into FailoverState - Extract duplicated failover logic from gateway_handler.go (3 places) and gemini_v1beta_handler.go into shared failover_loop.go - Introduce FailoverState with HandleFailoverError and HandleSelectionExhausted - Move helper functions (needForceCacheBilling, sleepWithContext) into failover_loop.go - Add comprehensive unit tests (32+ test cases) - Delete redundant gateway_handler_single_account_retry_test.go --- backend/internal/handler/failover_loop.go | 160 ++++ .../internal/handler/failover_loop_test.go | 732 ++++++++++++++++++ backend/internal/handler/gateway_handler.go | 260 ++----- ...teway_handler_single_account_retry_test.go | 51 -- .../internal/handler/gemini_v1beta_handler.go | 75 +- 5 files changed, 975 insertions(+), 303 deletions(-) create mode 100644 backend/internal/handler/failover_loop.go create mode 100644 backend/internal/handler/failover_loop_test.go delete mode 100644 backend/internal/handler/gateway_handler_single_account_retry_test.go diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go new file mode 100644 index 00000000..1f8a7e9a --- /dev/null +++ b/backend/internal/handler/failover_loop.go @@ -0,0 +1,160 @@ +package handler + +import ( + "context" + "log" + "net/http" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。 +// GatewayService 隐式实现此接口。 +type TempUnscheduler interface { + TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) +} + +// FailoverAction 表示 failover 错误处理后的下一步动作 +type FailoverAction int + +const ( + // FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue) + FailoverContinue FailoverAction = iota + // FailoverExhausted 切换次数耗尽(调用方应返回错误响应) + FailoverExhausted + // FailoverCanceled context 已取消(调用方应直接 return) + FailoverCanceled +) + +const ( + // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) + maxSameAccountRetries = 2 + // sameAccountRetryDelay 同账号重试间隔 + sameAccountRetryDelay = 500 * time.Millisecond + // singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。 + // Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s), + // Handler 层只需短暂间隔后重新进入 Service 层即可。 + singleAccountBackoffDelay = 2 * time.Second +) + +// FailoverState 跨循环迭代共享的 failover 状态 +type FailoverState struct { + SwitchCount int + MaxSwitches int + FailedAccountIDs map[int64]struct{} + SameAccountRetryCount map[int64]int + LastFailoverErr *service.UpstreamFailoverError + ForceCacheBilling bool + hasBoundSession bool +} + +// NewFailoverState 创建 failover 状态 +func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState { + return &FailoverState{ + MaxSwitches: maxSwitches, + FailedAccountIDs: make(map[int64]struct{}), + SameAccountRetryCount: make(map[int64]int), + hasBoundSession: hasBoundSession, + } +} + +// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。 +// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。 +func (s *FailoverState) HandleFailoverError( + ctx context.Context, + gatewayService TempUnscheduler, + accountID int64, + platform string, + failoverErr *service.UpstreamFailoverError, +) FailoverAction { + s.LastFailoverErr = failoverErr + + // 缓存计费判断 + if needForceCacheBilling(s.hasBoundSession, failoverErr) { + s.ForceCacheBilling = true + } + + // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 + if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries { + s.SameAccountRetryCount[accountID]++ + log.Printf("Account %d: retryable error %d, same-account retry %d/%d", + accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries) + if !sleepWithContext(ctx, sameAccountRetryDelay) { + return FailoverCanceled + } + return FailoverContinue + } + + // 同账号重试用尽,执行临时封禁 + if failoverErr.RetryableOnSameAccount { + gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr) + } + + // 加入失败列表 + s.FailedAccountIDs[accountID] = struct{}{} + + // 检查是否耗尽 + if s.SwitchCount >= s.MaxSwitches { + return FailoverExhausted + } + + // 递增切换计数 + s.SwitchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", + accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches) + + // Antigravity 平台换号线性递增延时 + if platform == service.PlatformAntigravity { + delay := time.Duration(s.SwitchCount-1) * time.Second + if !sleepWithContext(ctx, delay) { + return FailoverCanceled + } + } + + return FailoverContinue +} + +// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。 +// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景: +// 清除排除列表、等待退避后重新选号。 +// +// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。 +// 返回 FailoverExhausted 时,调用方应返回错误响应。 +// 返回 FailoverCanceled 时,调用方应直接 return。 +func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction { + if s.LastFailoverErr != nil && + s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable && + s.SwitchCount <= s.MaxSwitches { + + log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", + singleAccountBackoffDelay, s.SwitchCount) + if !sleepWithContext(ctx, singleAccountBackoffDelay) { + return FailoverCanceled + } + log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", + s.SwitchCount, s.MaxSwitches) + s.FailedAccountIDs = make(map[int64]struct{}) + return FailoverContinue + } + return FailoverExhausted +} + +// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。 +// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。 +func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { + return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) +} + +// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。 +func sleepWithContext(ctx context.Context, d time.Duration) bool { + if d <= 0 { + return true + } + select { + case <-ctx.Done(): + return false + case <-time.After(d): + return true + } +} diff --git a/backend/internal/handler/failover_loop_test.go b/backend/internal/handler/failover_loop_test.go new file mode 100644 index 00000000..5a41b2dd --- /dev/null +++ b/backend/internal/handler/failover_loop_test.go @@ -0,0 +1,732 @@ +package handler + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock +// --------------------------------------------------------------------------- + +// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。 +type mockTempUnscheduler struct { + calls []tempUnscheduleCall +} + +type tempUnscheduleCall struct { + accountID int64 + failoverErr *service.UpstreamFailoverError +} + +func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) { + m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr}) +} + +// --------------------------------------------------------------------------- +// Helper +// --------------------------------------------------------------------------- + +func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError { + return &service.UpstreamFailoverError{ + StatusCode: statusCode, + RetryableOnSameAccount: retryable, + ForceCacheBilling: forceBilling, + } +} + +// --------------------------------------------------------------------------- +// NewFailoverState 测试 +// --------------------------------------------------------------------------- + +func TestNewFailoverState(t *testing.T) { + t.Run("初始化字段正确", func(t *testing.T) { + fs := NewFailoverState(5, true) + require.Equal(t, 5, fs.MaxSwitches) + require.Equal(t, 0, fs.SwitchCount) + require.NotNil(t, fs.FailedAccountIDs) + require.Empty(t, fs.FailedAccountIDs) + require.NotNil(t, fs.SameAccountRetryCount) + require.Empty(t, fs.SameAccountRetryCount) + require.Nil(t, fs.LastFailoverErr) + require.False(t, fs.ForceCacheBilling) + require.True(t, fs.hasBoundSession) + }) + + t.Run("无绑定会话", func(t *testing.T) { + fs := NewFailoverState(3, false) + require.Equal(t, 3, fs.MaxSwitches) + require.False(t, fs.hasBoundSession) + }) + + t.Run("零最大切换次数", func(t *testing.T) { + fs := NewFailoverState(0, false) + require.Equal(t, 0, fs.MaxSwitches) + }) +} + +// --------------------------------------------------------------------------- +// sleepWithContext 测试 +// --------------------------------------------------------------------------- + +func TestSleepWithContext(t *testing.T) { + t.Run("零时长立即返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), 0) + require.True(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("负时长立即返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), -1*time.Second) + require.True(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("正常等待后返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), 50*time.Millisecond) + elapsed := time.Since(start) + require.True(t, ok) + require.GreaterOrEqual(t, elapsed, 40*time.Millisecond) + require.Less(t, elapsed, 500*time.Millisecond) + }) + + t.Run("已取消context立即返回false", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + ok := sleepWithContext(ctx, 5*time.Second) + require.False(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("等待期间context取消返回false", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(30 * time.Millisecond) + cancel() + }() + + start := time.Now() + ok := sleepWithContext(ctx, 5*time.Second) + elapsed := time.Since(start) + require.False(t, ok) + require.Less(t, elapsed, 500*time.Millisecond) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 基本切换流程 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_BasicSwitch(t *testing.T) { + t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + require.Equal(t, err, fs.LastFailoverErr) + require.False(t, fs.ForceCacheBilling) + require.Empty(t, mock.calls, "不应调用 TempUnschedule") + }) + + t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) { + // switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0 + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0") + }) + + t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) { + // switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 // 模拟已切换一次 + + err := newTestFailoverErr(500, false, false) + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s") + require.Less(t, elapsed, 3*time.Second) + }) + + t.Run("连续切换直到耗尽", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(2, false) + + // 第一次切换:0→1 + err1 := newTestFailoverErr(500, false, false) + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + + // 第二次切换:1→2 + err2 := newTestFailoverErr(502, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + + // 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2) + err3 := newTestFailoverErr(503, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3) + require.Equal(t, FailoverExhausted, action) + require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增") + + // 验证失败账号列表 + require.Len(t, fs.FailedAccountIDs, 3) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + require.Contains(t, fs.FailedAccountIDs, int64(200)) + require.Contains(t, fs.FailedAccountIDs, int64(300)) + + // LastFailoverErr 应为最后一次的错误 + require.Equal(t, err3, fs.LastFailoverErr) + }) + + t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(0, false) + err := newTestFailoverErr(500, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverExhausted, action) + require.Equal(t, 0, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 缓存计费 (ForceCacheBilling) +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_CacheBilling(t *testing.T) { + t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, true) // hasBoundSession=true + err := newTestFailoverErr(500, false, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.True(t, fs.ForceCacheBilling) + }) + + t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.True(t, fs.ForceCacheBilling) + }) + + t.Run("两者均为false时不设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.False(t, fs.ForceCacheBilling) + }) + + t.Run("一旦设置不会被后续错误重置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + // 第一次:ForceCacheBilling=true → 设置 + err1 := newTestFailoverErr(500, false, true) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.True(t, fs.ForceCacheBilling) + + // 第二次:ForceCacheBilling=false → 仍然保持 true + err2 := newTestFailoverErr(502, false, false) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 同账号重试 (RetryableOnSameAccount) +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_SameAccountRetry(t *testing.T) { + t.Run("第一次重试返回FailoverContinue", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数") + require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表") + require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule") + // 验证等待了 sameAccountRetryDelay (500ms) + require.GreaterOrEqual(t, elapsed, 400*time.Millisecond) + require.Less(t, elapsed, 2*time.Second) + }) + + t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + // 第一次 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + + // 第二次 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SameAccountRetryCount[100]) + + require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule") + }) + + t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + // 第一次、第二次重试 + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, 2, fs.SameAccountRetryCount[100]) + + // 第三次:重试已达到 maxSameAccountRetries(2),应切换账号 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + + // 验证 TempUnschedule 被调用 + require.Len(t, mock.calls, 1) + require.Equal(t, int64(100), mock.calls[0].accountID) + require.Equal(t, err, mock.calls[0].failoverErr) + }) + + t.Run("不同账号独立跟踪重试次数", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + err := newTestFailoverErr(400, true, false) + + // 账号 100 第一次重试 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + + // 账号 200 第一次重试(独立计数) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[200]) + require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响") + }) + + t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + err := newTestFailoverErr(400, true, false) + + // 耗尽账号 100 的重试 + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + // 第三次: 重试耗尽 → 切换 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + + // 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — TempUnschedule 调用验证 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_TempUnschedule(t *testing.T) { + t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Empty(t, mock.calls) + }) + + t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(502, true, false) + + // 耗尽重试 + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + + require.Len(t, mock.calls, 1) + require.Equal(t, int64(42), mock.calls[0].accountID) + require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode) + require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — Context 取消 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_ContextCanceled(t *testing.T) { + t.Run("同账号重试sleep期间context取消", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + action := fs.HandleFailoverError(ctx, mock, 100, "openai", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回") + // 重试计数仍应递增 + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + }) + + t.Run("Antigravity延迟期间context取消", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s + err := newTestFailoverErr(500, false, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — FailedAccountIDs 跟踪 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_FailedAccountIDs(t *testing.T) { + t.Run("切换时添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + + fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false)) + require.Contains(t, fs.FailedAccountIDs, int64(200)) + require.Len(t, fs.FailedAccountIDs, 2) + }) + + t.Run("耗尽时也添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(0, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Equal(t, FailoverExhausted, action) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + }) + + t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false)) + require.Equal(t, FailoverContinue, action) + require.NotContains(t, fs.FailedAccountIDs, int64(100)) + }) + + t.Run("同一账号多次切换不重复添加", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — LastFailoverErr 更新 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_LastFailoverErr(t *testing.T) { + t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + err1 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.Equal(t, err1, fs.LastFailoverErr) + + err2 := newTestFailoverErr(502, false, false) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.Equal(t, err2, fs.LastFailoverErr) + }) + + t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + err := newTestFailoverErr(400, true, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, err, fs.LastFailoverErr) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 综合集成场景 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_IntegrationScenario(t *testing.T) { + t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, true) // hasBoundSession=true + + // 1. 账号 100 遇到可重试错误,同账号重试 2 次 + retryErr := newTestFailoverErr(400, true, false) + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling") + + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + + // 2. 账号 100 重试耗尽 → TempUnschedule + 切换 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Len(t, mock.calls, 1) + + // 3. 账号 200 遇到不可重试错误 → 直接切换 + switchErr := newTestFailoverErr(500, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + + // 4. 账号 300 遇到不可重试错误 → 再切换 + action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 3, fs.SwitchCount) + + // 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3) + action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr) + require.Equal(t, FailoverExhausted, action) + + // 最终状态验证 + require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增") + require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中") + require.True(t, fs.ForceCacheBilling) + require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule") + }) + + t.Run("模拟Antigravity平台完整流程", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(2, false) + + err := newTestFailoverErr(500, false, false) + + // 第一次切换:delay = 0s + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + require.Equal(t, FailoverContinue, action) + require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0") + + // 第二次切换:delay = 1s + start = time.Now() + action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err) + elapsed = time.Since(start) + require.Equal(t, FailoverContinue, action) + require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s") + + // 第三次:耗尽(无延迟,因为在检查延迟之前就返回了) + start = time.Now() + action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err) + elapsed = time.Since(start) + require.Equal(t, FailoverExhausted, action) + require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟") + }) + + t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) // hasBoundSession=false + + // 第一次:ForceCacheBilling=false + err1 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.False(t, fs.ForceCacheBilling) + + // 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换) + err2 := newTestFailoverErr(500, false, true) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling") + + // 第三次:ForceCacheBilling=false,但状态仍保持 true + err3 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3) + require.True(t, fs.ForceCacheBilling, "不应重置") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 边界条件 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_EdgeCases(t *testing.T) { + t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(0, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + }) + + t.Run("AccountID为0也能正常跟踪", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, true, false) + + action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[0]) + }) + + t.Run("负AccountID也能正常跟踪", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, true, false) + + action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[-1]) + }) + + t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 + err := newTestFailoverErr(500, false, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, "", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟") + }) +} + +// --------------------------------------------------------------------------- +// HandleSelectionExhausted 测试 +// --------------------------------------------------------------------------- + +func TestHandleSelectionExhausted(t *testing.T) { + t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(3, false) + // LastFailoverErr 为 nil + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverExhausted, action) + }) + + t.Run("非503错误返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(500, false, false) + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverExhausted, action) + }) + + t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.FailedAccountIDs[100] = struct{}{} + fs.SwitchCount = 1 + + start := time.Now() + action := fs.HandleSelectionExhausted(context.Background()) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表") + require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s") + require.Less(t, elapsed, 5*time.Second) + }) + + t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(2, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.SwitchCount = 3 // > MaxSwitches(2) + + start := time.Now() + action := fs.HandleSelectionExhausted(context.Background()) + elapsed := time.Since(start) + + require.Equal(t, FailoverExhausted, action) + require.Less(t, elapsed, 100*time.Millisecond, "不应等待") + }) + + t.Run("503但context已取消_返回Canceled", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + action := fs.HandleSelectionExhausted(ctx) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回") + }) + + t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) { + fs := NewFailoverState(2, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试 + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverContinue, action) + }) +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 4b32969f..d78f83a6 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "log" "net/http" "strings" "time" @@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 if platform == service.PlatformGemini { - maxAccountSwitches := h.maxAccountSwitchesGemini - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - sameAccountRetryCount := make(map[int64]int) // 同账号重试计数 - var lastFailoverErr *service.UpstreamFailoverError - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 @@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { - if len(failedAccountIDs) == 0 { - reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs))) - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + if len(fs.FailedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - reqLog.Warn("gateway.single_account_retrying", - zap.Int("retry_count", switchCount), - zap.Int("max_retries", maxAccountSwitches), - ) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + if fs.LastFailoverErr != nil { + h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) } + return } - if lastFailoverErr != nil { - h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted) - } else { - h.handleFailoverExhaustedSimple(c, 502, streamStarted) - } - return } account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) @@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) } if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) @@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - lastFailoverErr = failoverErr - if needForceCacheBilling(hasBoundSession, failoverErr) { - forceCacheBilling = true - } - - // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 - if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries { - sameAccountRetryCount[account.ID]++ - log.Printf("Account %d: retryable error %d, same-account retry %d/%d", - account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries) - if !sleepSameAccountRetryDelay(c.Request.Context()) { - return - } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: continue - } - - // 同账号重试用尽,执行临时封禁并切换账号 - if failoverErr.RetryableOnSameAccount { - h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr) - } - - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) + case FailoverExhausted: + h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) + return + case FailoverCanceled: return } - switchCount++ - reqLog.Warn("gateway.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) reqLog.Error("gateway.forward_failed", @@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Subscription: subscription, UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: forceCacheBilling, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( @@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } for { - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - sameAccountRetryCount := make(map[int64]int) // 同账号重试计数 - var lastFailoverErr *service.UpstreamFailoverError + fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession) retryWithFallback := false - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID) if err != nil { - if len(failedAccountIDs) == 0 { - reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs))) - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + if len(fs.FailedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - reqLog.Warn("gateway.single_account_retrying", - zap.Int("retry_count", switchCount), - zap.Int("max_retries", maxAccountSwitches), - ) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + if fs.LastFailoverErr != nil { + h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) } + return } - if lastFailoverErr != nil { - h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted) - } else { - h.handleFailoverExhaustedSimple(c, 502, streamStarted) - } - return } account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) @@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) @@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - lastFailoverErr = failoverErr - if needForceCacheBilling(hasBoundSession, failoverErr) { - forceCacheBilling = true - } - - // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 - if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries { - sameAccountRetryCount[account.ID]++ - log.Printf("Account %d: retryable error %d, same-account retry %d/%d", - account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries) - if !sleepSameAccountRetryDelay(c.Request.Context()) { - return - } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: continue - } - - // 同账号重试用尽,执行临时封禁并切换账号 - if failoverErr.RetryableOnSameAccount { - h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr) - } - - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) + case FailoverExhausted: + h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted) + return + case FailoverCanceled: return } - switchCount++ - reqLog.Warn("gateway.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) reqLog.Error("gateway.forward_failed", @@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Subscription: currentSubscription, UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: forceCacheBilling, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( @@ -735,7 +652,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { }) reqLog.Debug("gateway.request_completed", zap.Int64("account_id", account.ID), - zap.Int("switch_count", switchCount), + zap.Int("switch_count", fs.SwitchCount), zap.Bool("fallback_used", fallbackUsed), ) return @@ -982,69 +899,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -// needForceCacheBilling 判断 failover 时是否需要强制缓存计费 -// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费 -func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { - return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) -} - -const ( - // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) - maxSameAccountRetries = 2 - // sameAccountRetryDelay 同账号重试间隔 - sameAccountRetryDelay = 500 * time.Millisecond -) - -// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。 -func sleepSameAccountRetryDelay(ctx context.Context) bool { - select { - case <-ctx.Done(): - return false - case <-time.After(sameAccountRetryDelay): - return true - } -} - -// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s… -// 返回 false 表示 context 已取消。 -func sleepFailoverDelay(ctx context.Context, switchCount int) bool { - delay := time.Duration(switchCount-1) * time.Second - if delay <= 0 { - return true - } - select { - case <-ctx.Done(): - return false - case <-time.After(delay): - return true - } -} - -// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。 -// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用, -// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试 -// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。 -// 返回 false 表示 context 已取消。 -func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool { - // 固定短延时:2s - // Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数), - // Handler 层只需短暂间隔后重新进入 Service 层即可。 - const delay = 2 * time.Second - - logger.L().With( - zap.String("component", "handler.gateway.failover"), - zap.Duration("delay", delay), - zap.Int("retry_count", retryCount), - ).Info("gateway.single_account_backoff_waiting") - - select { - case <-ctx.Done(): - return false - case <-time.After(delay): - return true - } -} - func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody diff --git a/backend/internal/handler/gateway_handler_single_account_retry_test.go b/backend/internal/handler/gateway_handler_single_account_retry_test.go deleted file mode 100644 index 96aa14c6..00000000 --- a/backend/internal/handler/gateway_handler_single_account_retry_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package handler - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -// --------------------------------------------------------------------------- -// sleepAntigravitySingleAccountBackoff 测试 -// --------------------------------------------------------------------------- - -func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) { - ctx := context.Background() - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 1) - elapsed := time.Since(start) - - require.True(t, ok, "should return true when context is not canceled") - // 固定延迟 2s - require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s") - require.Less(t, elapsed, 5*time.Second, "should not wait too long") -} - -func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() // 立即取消 - - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 1) - elapsed := time.Since(start) - - require.False(t, ok, "should return false when context is canceled") - require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel") -} - -func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) { - // 验证不同 retryCount 都使用固定 2s 延迟 - ctx := context.Background() - - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 5) - elapsed := time.Since(start) - - require.True(t, ok) - // 即使 retryCount=5,延迟仍然是固定的 2s - require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond) - require.Less(t, elapsed, 5*time.Second) -} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index ea212088..2da0570b 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 cleanedForUnknownBinding := false - maxAccountSwitches := h.maxAccountSwitchesGemini - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - var lastFailoverErr *service.UpstreamFailoverError - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 @@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { - if len(failedAccountIDs) == 0 { + if len(fs.FailedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - reqLog.Warn("gemini.single_account_retrying", - zap.Int("retry_count", switchCount), - zap.Int("max_retries", maxAccountSwitches), - ) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue - } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr) + return } - h.handleGeminiFailoverExhausted(c, lastFailoverErr) - return } account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) @@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 5) forward (根据平台分流) var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) @@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - if needForceCacheBilling(hasBoundSession, failoverErr) { - forceCacheBilling = true - } - if switchCount >= maxAccountSwitches { - lastFailoverErr = failoverErr - h.handleGeminiFailoverExhausted(c, lastFailoverErr) + failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch failoverAction { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr) + return + case FailoverCanceled: return } - lastFailoverErr = failoverErr - switchCount++ - reqLog.Warn("gemini.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } // ForwardNative already wrote the response reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) @@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { IPAddress: clientIP, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 - ForceCacheBilling: forceCacheBilling, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( @@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { }) reqLog.Debug("gemini.request_completed", zap.Int64("account_id", account.ID), - zap.Int("switch_count", switchCount), + zap.Int("switch_count", fs.SwitchCount), ) return } From 4573868c08e7c0b82fda86596ea762d987a324e1 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 18:08:19 +0800 Subject: [PATCH 06/24] fix(antigravity): bill with mapped model and use final model key for rate limiting - Use mapped model (billingModel) instead of original request model for billing - Use resolveFinalAntigravityModelKey for 429 rate limit model key, ensuring rate limit records match the actual upstream model - Add regression tests for both fixes --- .../service/antigravity_gateway_service.go | 23 ++- .../antigravity_gateway_service_test.go | 142 +++++++++++++++++- .../service/antigravity_rate_limit_test.go | 16 ++ 3 files changed, 172 insertions(+), 9 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 26b14e68..108ff9ab 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -87,7 +87,6 @@ var ( ) const ( - antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" ) @@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) + billingModel := mappedModel // 获取 access_token if s.tokenProvider == nil { @@ -1622,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 + Model: billingModel, // 使用映射模型用于计费和日志 Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -1976,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if mappedModel == "" { return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) } + billingModel := mappedModel // 获取 access_token if s.tokenProvider == nil { @@ -2205,7 +2206,7 @@ handleSuccess: return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, + Model: billingModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -2650,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError( defaultDur := s.getDefaultRateLimitDuration() // 尝试解析模型 key 并设置模型级限流 - modelKey := resolveAntigravityModelKey(requestedModel) + // + // 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6), + // 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。 + // 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。 + modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel) + if strings.TrimSpace(modelKey) == "" { + // 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过), + // 保持旧行为作为兜底,避免完全丢失模型级限流记录。 + modelKey = resolveAntigravityModelKey(requestedModel) + } if modelKey != "" { ra := s.resolveResetTime(resetAt, defaultDur) if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil { @@ -3889,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, fmt.Errorf("missing model") } originalModel := claudeReq.Model - billingModel := originalModel // 构建上游请求 URL upstreamURL := baseURL + "/v1/messages" @@ -3942,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. _, _ = c.Writer.Write(respBody) return &ForwardResult{ - Model: billingModel, + Model: originalModel, }, nil } @@ -3983,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds()) return &ForwardResult{ - Model: billingModel, + Model: originalModel, Stream: claudeReq.Stream, Duration: duration, FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index abe7b75d..84b65adc 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, return s.resp, s.err } +type antigravitySettingRepoStub struct{} + +func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + return "", ErrSettingNotFound +} + +func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() @@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { } svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: &httpUpstreamStub{resp: resp}, + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, } account := &Account{ @@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } +// TestAntigravityGatewayService_Forward_BillsWithMappedModel +// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型 +func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + "max_tokens": 16, + "stream": true, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-bill-1"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + const mappedModel = "gemini-3-pro-high" + account := &Account{ + ID: 5, + Name: "acc-forward-billing", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + "claude-sonnet-4-5": mappedModel, + }, + }, + } + + result, err := svc.Forward(context.Background(), c, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) +} + +// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel +// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型 +func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + c.Request = req + + upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-bill-2"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + const mappedModel = "gemini-3-pro-high" + account := &Account{ + ID: 6, + Name: "acc-gemini-billing", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + "gemini-2.5-flash": mappedModel, + }, + }, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) +} + // TestStreamUpstreamResponse_UsageAndFirstToken // 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 6a486ebc..dd8dd83f 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -197,6 +197,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } +// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景 +// 验证:requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking) +func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity} + + body := buildGeminiRateLimitBody("5s") + + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false) + + require.Nil(t, result) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey) +} + // TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景 // MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号 func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) { From 644058174e9ed212e5d2658c5a610d33b4763673 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 18:54:59 +0800 Subject: [PATCH 07/24] fix(gemini): enable model_mapping filtering for Gemini API Key accounts Remove the special case that bypassed model-supported checks for Gemini API Key accounts, allowing model_mapping to filter requests properly. Add tests for multiplatform model filtering behavior. --- .../service/gateway_multiplatform_test.go | 79 +++++++++++++++++++ backend/internal/service/gateway_service.go | 4 - 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 70d5068b..5055eec0 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -895,6 +895,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t require.Equal(t, int64(2), acc.ID) } +func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + { + ID: 1, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}}, + }, + { + ID: 2, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Priority: 2, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}}, + }, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号") + + acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "supporting model") +} + func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) { ctx := context.Background() groupID := int64(50) @@ -1070,6 +1119,36 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { model: "claude-3-5-sonnet-20241022", expected: true, }, + { + name: "Gemini平台-无映射配置-支持所有模型", + account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey}, + model: "gemini-2.5-flash", + expected: true, + }, + { + name: "Gemini平台-有映射配置-只支持配置的模型", + account: &Account{ + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}, + }, + }, + model: "gemini-2.5-flash", + expected: false, + }, + { + name: "Gemini平台-有映射配置-支持配置的模型", + account: &Account{ + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}, + }, + }, + model: "gemini-2.5-pro", + expected: true, + }, } for _, tt := range tests { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e55940ee..5c14e7f9 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -2825,10 +2825,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) } - // Gemini API Key 账户直接透传,由上游判断模型是否支持 - if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey { - return true - } // 其他平台使用账户的模型支持检查 return account.IsModelSupported(requestedModel) } From 86bc76e352e0728d45dbf043ff7ff5a934900d98 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 18:55:11 +0800 Subject: [PATCH 08/24] test: add warmup request interception unit tests Add comprehensive tests for warmup request interception behavior covering Antigravity accounts with various credential configurations. --- ...eway_handler_warmup_intercept_unit_test.go | 340 ++++++++++++++++++ 1 file changed, 340 insertions(+) create mode 100644 backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go new file mode 100644 index 00000000..15d85949 --- /dev/null +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -0,0 +1,340 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”, +// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时, +// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。 + +type fakeSchedulerCache struct { + accounts []*service.Account +} + +func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) { + return f.accounts, true, nil +} +func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error { + return nil +} +func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) { + return nil, nil +} +func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil } +func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil } +func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error { + return nil +} +func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) { + return true, nil +} +func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) { + return nil, nil +} +func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil } +func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil } + +type fakeGroupRepo struct { + group *service.Group +} + +func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil } +func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) { + return f.group, nil +} +func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) { + return f.group, nil +} +func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil } +func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil } +func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil } +func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil } +func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) { + return nil, nil +} +func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } +func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil } +func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { + return nil, nil +} +func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil } +func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error { + return nil +} + +type fakeConcurrencyCache struct{} + +func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil } +func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) { + return 0, nil +} +func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) { + return 0, nil +} +func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil } +func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil } +func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + return map[int64]*service.AccountLoadInfo{}, nil +} +func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + return map[int64]*service.UserLoadInfo{}, nil +} +func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } + +func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { + t.Helper() + + schedulerCache := &fakeSchedulerCache{accounts: accounts} + schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil) + + gwSvc := service.NewGatewayService( + nil, // accountRepo (not used: scheduler snapshot hit) + &fakeGroupRepo{group: group}, + nil, // usageLogRepo + nil, // userRepo + nil, // userSubRepo + nil, // userGroupRateRepo + nil, // cache (disable sticky) + nil, // cfg + schedulerSnapshot, + nil, // concurrencyService (disable load-aware; tryAcquire always acquired) + nil, // billingService + nil, // rateLimitService + nil, // billingCacheService + nil, // identityService + nil, // httpUpstream + nil, // deferredService + nil, // claudeTokenProvider + nil, // sessionLimitCache + nil, // digestStore + ) + + // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 + cfg := &config.Config{RunMode: config.RunModeSimple} + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) + + concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) + concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) + + h := &GatewayHandler{ + gatewayService: gwSvc, + billingCacheService: billingCacheSvc, + concurrencyHelper: concurrencyHelper, + // 这些字段对本测试不敏感,保持较小即可 + maxAccountSwitches: 1, + maxAccountSwitchesGemini: 1, + } + + cleanup := func() { + billingCacheSvc.Stop() + } + return h, cleanup +} + +func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(2001) + accountID := int64(1001) + + group := &service.Group{ + ID: groupID, + Hydrated: true, + Platform: service.PlatformAnthropic, // /v1/messages(Claude兼容)入口 + Status: service.StatusActive, + } + + account := &service.Account{ + ID: accountID, + Name: "ag-1", + Platform: service.PlatformAntigravity, + Type: service.AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "tok_xxx", + "intercept_warmup_requests": true, + }, + Extra: map[string]any{ + "mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中 + }, + Concurrency: 1, + Priority: 1, + Status: service.StatusActive, + Schedulable: true, + AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}}, + } + + h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account}) + defer cleanup() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{ + "model": "claude-sonnet-4-5", + "max_tokens": 256, + "messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}] + }`) + req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group)) + c.Request = req + + apiKey := &service.APIKey{ + ID: 3001, + UserID: 4001, + GroupID: &groupID, + Status: service.StatusActive, + User: &service.User{ + ID: 4001, + Concurrency: 10, + Balance: 100, + }, + Group: group, + } + + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10}) + + h.Messages(c) + + require.Equal(t, 200, rec.Code) + + // 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果) + selected, ok := c.Get(opsAccountIDKey) + require.True(t, ok) + require.Equal(t, accountID, selected) + + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "msg_mock_warmup", resp["id"]) + require.Equal(t, "claude-sonnet-4-5", resp["model"]) + + content, ok := resp["content"].([]any) + require.True(t, ok) + require.Len(t, content, 1) + first, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "New Conversation", first["text"]) +} + +func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(2002) + accountID := int64(1002) + + group := &service.Group{ + ID: groupID, + Hydrated: true, + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + } + + account := &service.Account{ + ID: accountID, + Name: "ag-2", + Platform: service.PlatformAntigravity, + Type: service.AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "tok_xxx", + "intercept_warmup_requests": true, + }, + Concurrency: 1, + Priority: 1, + Status: service.StatusActive, + Schedulable: true, + AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}}, + } + + h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account}) + defer cleanup() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{ + "model": "claude-sonnet-4-5", + "max_tokens": 256, + "messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}] + }`) + req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + // 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果: + // - 写入 request.Context(Service读取) + // - 写入 gin.Context(Handler快速读取) + ctx := context.WithValue(req.Context(), ctxkey.Group, group) + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity) + req = req.WithContext(ctx) + c.Request = req + c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity) + + apiKey := &service.APIKey{ + ID: 3002, + UserID: 4002, + GroupID: &groupID, + Status: service.StatusActive, + User: &service.User{ + ID: 4002, + Concurrency: 10, + Balance: 100, + }, + Group: group, + } + + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10}) + + h.Messages(c) + + require.Equal(t, 200, rec.Code) + + selected, ok := c.Get(opsAccountIDKey) + require.True(t, ok) + require.Equal(t, accountID, selected) + + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "msg_mock_warmup", resp["id"]) + require.Equal(t, "claude-sonnet-4-5", resp["model"]) +} From fb3ef5f388e5ee2f79a5b6137870e98873684409 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 18:55:25 +0800 Subject: [PATCH 09/24] fix(frontend): add Gemini models to bulk edit and fix status grid layout Add Gemini model presets to BulkEditAccountModal for bulk model mapping. Fix AccountStatusIndicator model rate limit grid layout using proper grid container. --- .../account/AccountStatusIndicator.vue | 4 +- .../account/BulkEditAccountModal.vue | 63 ++++++++++--------- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 5fe96a1d..af32ea0c 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -77,7 +77,7 @@ - +
diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 75fffc19..0997f6ee 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -209,7 +209,7 @@
('whitelist') const allowedModels = ref([]) const modelMappings = ref([]) -const getModelMappingKey = createStableObjectKeyResolver('bulk-model-mapping') const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) @@ -707,7 +706,7 @@ const rateMultiplier = ref(1) const status = ref<'active' | 'inactive'>('active') const groupIds = ref([]) -// All models list (combined Anthropic + OpenAI) +// All models list (combined Anthropic + OpenAI + Gemini) const allModels = [ { value: 'claude-opus-4-6', label: 'Claude Opus 4.6' }, { value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' }, @@ -719,17 +718,21 @@ const allModels = [ { value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' }, { value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' }, { value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' }, - { value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' }, { value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' }, { value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' }, { value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' }, { value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' }, { value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' }, { value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' }, - { value: 'gpt-5-2025-08-07', label: 'GPT-5' } + { value: 'gpt-5-2025-08-07', label: 'GPT-5' }, + { value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' }, + { value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' }, + { value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' }, + { value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' }, + { value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' } ] -// Preset mappings (combined Anthropic + OpenAI) +// Preset mappings (combined Anthropic + OpenAI + Gemini) const presetMappings = [ { label: 'Sonnet 4', @@ -771,12 +774,6 @@ const presetMappings = [ to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }, - { - label: 'GPT-5.3 Codex Spark', - from: 'gpt-5.3-codex-spark', - to: 'gpt-5.3-codex-spark', - color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' - }, { label: 'GPT-5.2', from: 'gpt-5.2-2025-12-11', @@ -794,6 +791,24 @@ const presetMappings = [ from: 'gpt-5.1-codex-max', to: 'gpt-5.1-codex', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' + }, + { + label: 'Gemini Flash 2.0', + from: 'gemini-2.0-flash', + to: 'gemini-2.0-flash', + color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' + }, + { + label: 'Gemini 2.5 Flash', + from: 'gemini-2.5-flash', + to: 'gemini-2.5-flash', + color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' + }, + { + label: 'Gemini 2.5 Pro', + from: 'gemini-2.5-pro', + to: 'gemini-2.5-pro', + color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' } ] @@ -883,23 +898,11 @@ const removeErrorCode = (code: number) => { } const buildModelMappingObject = (): Record | null => { - const mapping: Record = {} - - if (modelRestrictionMode.value === 'whitelist') { - for (const model of allowedModels.value) { - mapping[model] = model - } - } else { - for (const m of modelMappings.value) { - const from = m.from.trim() - const to = m.to.trim() - if (from && to) { - mapping[from] = to - } - } - } - - return Object.keys(mapping).length > 0 ? mapping : null + return buildModelMappingPayload( + modelRestrictionMode.value, + allowedModels.value, + modelMappings.value + ) } const buildUpdatePayload = (): Record | null => { From da6fd4500076ac9dffc84d9d8f7204a9a27ee405 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 18:55:39 +0800 Subject: [PATCH 10/24] chore: add sonnet-4-6 mapping, config defaults, and CI improvements - Add claude-sonnet-4-6 to default Antigravity model mapping - Add antigravity_extra_retries default value in config - Add cache-dependency-path to CI setup-go for faster builds - Simplify vitest config to avoid vite plugin compatibility issues --- .github/workflows/backend-ci.yml | 4 +++- backend/internal/config/config.go | 1 + backend/internal/domain/constants.go | 1 + .../internal/service/antigravity_model_mapping_test.go | 6 ++++++ frontend/vitest.config.ts | 9 +-------- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 4fd22aff..84575a96 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -17,6 +17,7 @@ jobs: go-version-file: backend/go.mod check-latest: false cache: true + cache-dependency-path: backend/go.sum - name: Verify Go version run: | go version | grep -q 'go1.25.7' @@ -36,6 +37,7 @@ jobs: go-version-file: backend/go.mod check-latest: false cache: true + cache-dependency-path: backend/go.sum - name: Verify Go version run: | go version | grep -q 'go1.25.7' @@ -44,4 +46,4 @@ jobs: with: version: v2.7 args: --timeout=5m - working-directory: backend \ No newline at end of file + working-directory: backend diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index c4d4fdab..8cd77724 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1158,6 +1158,7 @@ func setDefaults() { viper.SetDefault("gateway.force_codex_cli", false) viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) + viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 5f273486..d8604abd 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -74,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{ "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型 "claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射 "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型 + "claude-sonnet-4-6": "claude-sonnet-4-6", "claude-sonnet-4-5": "claude-sonnet-4-5", "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", // Claude 详细版本 ID 映射 diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index f3621555..71939d26 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -76,6 +76,12 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { }, // 3. 默认映射中的透传(映射到自己) + { + name: "默认映射透传 - claude-sonnet-4-6", + requestedModel: "claude-sonnet-4-6", + accountMapping: nil, + expected: "claude-sonnet-4-6", + }, { name: "默认映射透传 - claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5", diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts index 1007f6ed..2ff23c77 100644 --- a/frontend/vitest.config.ts +++ b/frontend/vitest.config.ts @@ -1,18 +1,13 @@ import { defineConfig } from 'vitest/config' -import vue from '@vitejs/plugin-vue' import { resolve } from 'path' export default defineConfig({ - plugins: [vue()], resolve: { alias: { '@': resolve(__dirname, 'src'), 'vue-i18n': 'vue-i18n/dist/vue-i18n.runtime.esm-bundler.js' } }, - define: { - __INTLIFY_JIT_COMPILATION__: true - }, test: { globals: true, environment: 'jsdom', @@ -37,8 +32,6 @@ export default defineConfig({ lines: 80 } } - }, - setupFiles: ['./src/__tests__/setup.ts'], - testTimeout: 10000 + } } }) From 645f283108c8192706d0565652e6871809ccc327 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 19:30:01 +0800 Subject: [PATCH 11/24] feat: add claude-sonnet-4-6 and gemini-3.1-pro model support Add claude-sonnet-4-6 to identity injection modelInfoMap and Antigravity model selector. Add gemini-3.1-pro-high/low to Antigravity model list and Sonnet 4.6 preset mapping. --- backend/internal/pkg/antigravity/request_transformer.go | 1 + frontend/src/composables/useModelWhitelist.ts | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 3ba04b95..55cdd786 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -206,6 +206,7 @@ type modelInfo struct { var modelInfoMap = map[string]modelInfo{ "claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"}, "claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"}, + "claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"}, "claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"}, "claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"}, } diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 80416f40..fc7bdc03 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -76,6 +76,7 @@ const antigravityModels = [ // Claude 4.5+ 系列 'claude-opus-4-6', 'claude-opus-4-5-thinking', + 'claude-sonnet-4-6', 'claude-sonnet-4-5', 'claude-sonnet-4-5-thinking', // Gemini 2.5 系列 @@ -88,6 +89,9 @@ const antigravityModels = [ 'gemini-3-pro-high', 'gemini-3-pro-low', 'gemini-3-pro-image', + // Gemini 3.1 系列 + 'gemini-3.1-pro-high', + 'gemini-3.1-pro-low', // 其他 'gpt-oss-120b-medium', 'tab_flash_lite_preview' @@ -295,6 +299,7 @@ const antigravityPresetMappings = [ { label: 'Gemini 3→Flash', from: 'gemini-3*', to: 'gemini-3-flash', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' }, { label: 'Gemini 2.5→Flash', from: 'gemini-2.5*', to: 'gemini-2.5-flash', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' }, // 精确映射 + { label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, { label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, { label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' } ] From 483c8f246d77c89b8c907ccfd0573a62b6ea158f Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 19:39:15 +0800 Subject: [PATCH 12/24] chore: update default Antigravity UserAgent version to 1.18.4 Update the default ANTIGRAVITY_USER_AGENT_VERSION from 1.84.2 to 1.18.4 to match the current Antigravity-Manager desktop client. --- backend/internal/pkg/antigravity/oauth.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index ba84a247..e916859f 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -51,8 +51,8 @@ const ( antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) -// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.84.2 -var defaultUserAgentVersion = "1.84.2" +// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.18.4 +var defaultUserAgentVersion = "1.18.4" func init() { // 从环境变量读取版本号,未设置则使用默认值 From 29c406dda0505b69b4720e00ac6aa9b56dc61029 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 19:40:30 +0800 Subject: [PATCH 13/24] feat: add migrations for sonnet-4-6 and gemini-3.1-pro model mappings Add migration 058 to update existing Antigravity accounts with claude-sonnet-4-6 in model_mapping. Add migration 059 to add gemini-3.1-pro-high/low/preview mappings. --- .../058_add_sonnet46_to_model_mapping.sql | 42 +++++++++++++++++ .../059_add_gemini31_pro_to_model_mapping.sql | 45 +++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 backend/migrations/058_add_sonnet46_to_model_mapping.sql create mode 100644 backend/migrations/059_add_gemini31_pro_to_model_mapping.sql diff --git a/backend/migrations/058_add_sonnet46_to_model_mapping.sql b/backend/migrations/058_add_sonnet46_to_model_mapping.sql new file mode 100644 index 00000000..aa7657d7 --- /dev/null +++ b/backend/migrations/058_add_sonnet46_to_model_mapping.sql @@ -0,0 +1,42 @@ +-- Add claude-sonnet-4-6 to model_mapping for all Antigravity accounts +-- +-- Background: +-- Antigravity now supports claude-sonnet-4-6 +-- +-- Strategy: +-- Directly overwrite the entire model_mapping with updated mappings +-- This ensures consistency with DefaultAntigravityModelMapping in constants.go + +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping}', + '{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-6": "claude-opus-4-6-thinking", + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-pro-image": "gemini-3-pro-image", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3-pro-image-preview": "gemini-3-pro-image", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview" + }'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL; diff --git a/backend/migrations/059_add_gemini31_pro_to_model_mapping.sql b/backend/migrations/059_add_gemini31_pro_to_model_mapping.sql new file mode 100644 index 00000000..6305e717 --- /dev/null +++ b/backend/migrations/059_add_gemini31_pro_to_model_mapping.sql @@ -0,0 +1,45 @@ +-- Add gemini-3.1-pro-high, gemini-3.1-pro-low, gemini-3.1-pro-preview to model_mapping +-- +-- Background: +-- Antigravity now supports gemini-3.1-pro-high and gemini-3.1-pro-low +-- +-- Strategy: +-- Directly overwrite the entire model_mapping with updated mappings +-- This ensures consistency with DefaultAntigravityModelMapping in constants.go + +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping}', + '{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-6": "claude-opus-4-6-thinking", + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-pro-image": "gemini-3-pro-image", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3-pro-image-preview": "gemini-3-pro-image", + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": "gemini-3.1-pro-low", + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview" + }'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL; From 6523b2322114c0f0d9c0808799f3236b30a9edcc Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 19:45:23 +0800 Subject: [PATCH 14/24] revert: remove backend-ci.yml changes (fork-specific CI config) --- .github/workflows/backend-ci.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 84575a96..4fd22aff 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -17,7 +17,6 @@ jobs: go-version-file: backend/go.mod check-latest: false cache: true - cache-dependency-path: backend/go.sum - name: Verify Go version run: | go version | grep -q 'go1.25.7' @@ -37,7 +36,6 @@ jobs: go-version-file: backend/go.mod check-latest: false cache: true - cache-dependency-path: backend/go.sum - name: Verify Go version run: | go version | grep -q 'go1.25.7' @@ -46,4 +44,4 @@ jobs: with: version: v2.7 args: --timeout=5m - working-directory: backend + working-directory: backend \ No newline at end of file From f92ab4816609edea67e532466dede1b8539d9934 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 20:06:19 +0800 Subject: [PATCH 15/24] fix: add gemini-3.1-pro-preview to default Antigravity model mapping Add missing gemini-3.1-pro-preview -> gemini-3.1-pro-high mapping to DefaultAntigravityModelMapping for consistency with migration 059. --- backend/internal/domain/constants.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index d8604abd..53658270 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -94,8 +94,9 @@ var DefaultAntigravityModelMapping = map[string]string{ "gemini-3-pro-low": "gemini-3.1-pro-low", "gemini-3-pro-image": "gemini-3-pro-image", // Gemini 3.1 透传 - "gemini-3.1-pro-high": "gemini-3.1-pro-high", - "gemini-3.1-pro-low": "gemini-3.1-pro-low", + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": "gemini-3.1-pro-low", + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", // Gemini 3 preview 映射 "gemini-3-flash-preview": "gemini-3-flash", "gemini-3-pro-preview": "gemini-3.1-pro-high", From ca3e9336e11c2db3e09fc6881ec1b3c6ea698874 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 20:31:02 +0800 Subject: [PATCH 16/24] test: update UserAgent version assertion to match 1.18.4 default --- backend/internal/pkg/antigravity/oauth_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index 78184941..d8f2b098 100644 --- a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -676,7 +676,7 @@ func TestConstants_值正确(t *testing.T) { if RedirectURI != "http://localhost:8085/callback" { t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) } - if GetUserAgent() != "antigravity/1.84.2 windows/amd64" { + if GetUserAgent() != "antigravity/1.18.4 windows/amd64" { t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) } if SessionTTL != 30*time.Minute { From 36d2e6999b09b83c32ea0761961c329703666093 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 20:54:28 +0800 Subject: [PATCH 17/24] feat: add default value for Antigravity OAuth client secret Add a built-in default for ANTIGRAVITY_OAUTH_CLIENT_SECRET so the service works out of the box without requiring environment variable configuration. The env var can still override the default. --- backend/internal/pkg/antigravity/oauth.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index e916859f..ca59774b 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -54,11 +54,18 @@ const ( // defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.18.4 var defaultUserAgentVersion = "1.18.4" +// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 +var defaultClientSecret = "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" + func init() { // 从环境变量读取版本号,未设置则使用默认值 if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" { defaultUserAgentVersion = version } + // 从环境变量读取 client_secret,未设置则使用默认值 + if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" { + defaultClientSecret = secret + } } // GetUserAgent 返回当前配置的 User-Agent @@ -67,14 +74,9 @@ func GetUserAgent() string { } func getClientSecret() (string, error) { - if v := strings.TrimSpace(ClientSecret); v != "" { + if v := strings.TrimSpace(defaultClientSecret); v != "" { return v, nil } - if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok { - if vv := strings.TrimSpace(v); vv != "" { - return vv, nil - } - } return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv) } From b6fa8b8eec438558a2e85650449c7e5c06602eee Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 21:06:10 +0800 Subject: [PATCH 18/24] fix: update tests for defaultClientSecret and align migration 058 - Fix oauth_test.go and client_test.go to use defaultClientSecret variable instead of env var (init() already sets the default) - Align migration 058 gemini-3-pro-high/low/preview mappings with constants.go (map to 3.1 versions) --- .../internal/pkg/antigravity/client_test.go | 52 ++++++++++++++----- .../internal/pkg/antigravity/oauth_test.go | 43 ++++++++++----- .../058_add_sonnet46_to_model_mapping.sql | 6 +-- 3 files changed, 72 insertions(+), 29 deletions(-) diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go index d3e2fd94..394b6128 100644 --- a/backend/internal/pkg/antigravity/client_test.go +++ b/backend/internal/pkg/antigravity/client_test.go @@ -400,7 +400,9 @@ func TestShouldFallbackToNextURL_无错误且200(t *testing.T) { // --------------------------------------------------------------------------- func TestClient_ExchangeCode_成功(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 验证请求方法 @@ -493,7 +495,9 @@ func TestClient_ExchangeCode_成功(t *testing.T) { } func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "") + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) client := NewClient("") _, err := client.ExchangeCode(context.Background(), "code", "verifier") @@ -506,7 +510,9 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { } func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) @@ -531,7 +537,9 @@ func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { // --------------------------------------------------------------------------- func TestClient_RefreshToken_MockServer(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -590,7 +598,9 @@ func TestClient_RefreshToken_MockServer(t *testing.T) { } func TestClient_RefreshToken_无ClientSecret(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "") + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) client := NewClient("") _, err := client.RefreshToken(context.Background(), "refresh-tok") @@ -784,7 +794,9 @@ func newTestClientWithRedirect(redirects map[string]string) *Client { // --------------------------------------------------------------------------- func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -853,7 +865,9 @@ func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { } func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) @@ -878,7 +892,9 @@ func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { } func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -901,7 +917,9 @@ func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { } func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(5 * time.Second) // 模拟慢响应 @@ -927,7 +945,9 @@ func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { // --------------------------------------------------------------------------- func TestClient_RefreshToken_Success_RealCall(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -976,7 +996,9 @@ func TestClient_RefreshToken_Success_RealCall(t *testing.T) { } func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) @@ -998,7 +1020,9 @@ func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { } func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -1021,7 +1045,9 @@ func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { } func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(5 * time.Second) diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index d8f2b098..da7d3fca 100644 --- a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "encoding/hex" "net/url" + "os" "strings" "testing" "time" @@ -17,8 +18,14 @@ import ( // --------------------------------------------------------------------------- func TestGetClientSecret_环境变量设置(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value") + // 需要重新触发 init 逻辑:手动从环境变量读取 + defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv) + secret, err := getClientSecret() if err != nil { t.Fatalf("获取 client_secret 失败: %v", err) @@ -29,11 +36,13 @@ func TestGetClientSecret_环境变量设置(t *testing.T) { } func TestGetClientSecret_环境变量为空(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, "") + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) _, err := getClientSecret() if err == nil { - t.Fatal("环境变量为空时应返回错误") + t.Fatal("defaultClientSecret 为空时应返回错误") } if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) @@ -41,30 +50,31 @@ func TestGetClientSecret_环境变量为空(t *testing.T) { } func TestGetClientSecret_环境变量未设置(t *testing.T) { - // t.Setenv 会在测试结束时恢复,但我们需要确保它不存在 - // 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值 - // 当前代码中 ClientSecret = "",所以会走环境变量逻辑 - - // 明确设置再取消,确保环境变量不存在 - t.Setenv(AntigravityOAuthClientSecretEnv, "") + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) _, err := getClientSecret() if err == nil { - t.Fatal("环境变量未设置时应返回错误") + t.Fatal("defaultClientSecret 为空时应返回错误") } } func TestGetClientSecret_环境变量含空格(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, " ") + old := defaultClientSecret + defaultClientSecret = " " + t.Cleanup(func() { defaultClientSecret = old }) _, err := getClientSecret() if err == nil { - t.Fatal("环境变量仅含空格时应返回错误") + t.Fatal("defaultClientSecret 仅含空格时应返回错误") } } func TestGetClientSecret_环境变量有前后空格(t *testing.T) { - t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ") + old := defaultClientSecret + defaultClientSecret = " valid-secret " + t.Cleanup(func() { defaultClientSecret = old }) secret, err := getClientSecret() if err != nil { @@ -671,7 +681,14 @@ func TestConstants_值正确(t *testing.T) { t.Errorf("ClientID 不匹配: got %s", ClientID) } if ClientSecret != "" { - t.Error("ClientSecret 应为空字符串") + t.Error("ClientSecret 常量应为空字符串(默认值已移至 defaultClientSecret)") + } + secret, err := getClientSecret() + if err != nil { + t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) + } + if secret != "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" { + t.Errorf("默认 client_secret 不匹配: got %s", secret) } if RedirectURI != "http://localhost:8085/callback" { t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) diff --git a/backend/migrations/058_add_sonnet46_to_model_mapping.sql b/backend/migrations/058_add_sonnet46_to_model_mapping.sql index aa7657d7..93e7b39d 100644 --- a/backend/migrations/058_add_sonnet46_to_model_mapping.sql +++ b/backend/migrations/058_add_sonnet46_to_model_mapping.sql @@ -27,11 +27,11 @@ SET credentials = jsonb_set( "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", "gemini-2.5-pro": "gemini-2.5-pro", "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-high": "gemini-3-pro-high", - "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-pro-high": "gemini-3.1-pro-high", + "gemini-3-pro-low": "gemini-3.1-pro-low", "gemini-3-pro-image": "gemini-3-pro-image", "gemini-3-flash-preview": "gemini-3-flash", - "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3-pro-preview": "gemini-3.1-pro-high", "gemini-3-pro-image-preview": "gemini-3-pro-image", "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview" From d616f8c854df4a88b0c168a99368f491dbeaccac Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 21:09:46 +0800 Subject: [PATCH 19/24] refactor: remove unused ClientSecret constant The ClientSecret constant was left as an empty string after getClientSecret() was refactored to use defaultClientSecret. Remove the dead constant and update the test accordingly. --- backend/internal/pkg/antigravity/oauth.go | 4 +--- backend/internal/pkg/antigravity/oauth_test.go | 3 --- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index ca59774b..47c75142 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -23,11 +23,9 @@ const ( UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo" // Antigravity OAuth 客户端凭证 - ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - ClientSecret = "" + ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" // AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。 - // 出于安全原因,该值不得硬编码入库。 AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET" // 固定的 redirect_uri(用户需手动复制 code) diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index da7d3fca..351708a5 100644 --- a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -680,9 +680,6 @@ func TestConstants_值正确(t *testing.T) { if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" { t.Errorf("ClientID 不匹配: got %s", ClientID) } - if ClientSecret != "" { - t.Error("ClientSecret 常量应为空字符串(默认值已移至 defaultClientSecret)") - } secret, err := getClientSecret() if err != nil { t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) From d8d4b0c0c73ca3135fa20dade7efe657e0def97e Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 21:30:32 +0800 Subject: [PATCH 20/24] fix: enable Gemini model_mapping UI and extend warmup to Antigravity - Remove Gemini platform exclusion from model restriction UI in Create/Edit account modals (Gemini now supports model_mapping) - Remove outdated Gemini model passthrough info cards - Add model_mapping field to GeminiCredentials type - Extend warmup request interception toggle to Antigravity platform - Remove redundant try/catch in API key account creation - Remove noisy gateway.request_completed debug log - Reorganize Gemini model mapping sections in constants.go --- backend/internal/domain/constants.go | 9 ++-- backend/internal/handler/gateway_handler.go | 5 -- .../components/account/CreateAccountModal.vue | 51 +++---------------- .../components/account/EditAccountModal.vue | 36 ++----------- frontend/src/types/index.ts | 1 + 5 files changed, 18 insertions(+), 84 deletions(-) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 53658270..e674facb 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -93,14 +93,15 @@ var DefaultAntigravityModelMapping = map[string]string{ "gemini-3-pro-high": "gemini-3.1-pro-high", "gemini-3-pro-low": "gemini-3.1-pro-low", "gemini-3-pro-image": "gemini-3-pro-image", - // Gemini 3.1 透传 - "gemini-3.1-pro-high": "gemini-3.1-pro-high", - "gemini-3.1-pro-low": "gemini-3.1-pro-low", - "gemini-3.1-pro-preview": "gemini-3.1-pro-high", // Gemini 3 preview 映射 "gemini-3-flash-preview": "gemini-3-flash", "gemini-3-pro-preview": "gemini-3.1-pro-high", "gemini-3-pro-image-preview": "gemini-3-pro-image", + // Gemini 3.1 白名单 + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": "gemini-3.1-pro-low", + // Gemini 3.1 preview 映射 + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", // 其他官方模型 "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview", diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index d78f83a6..fe40e9d2 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -650,11 +650,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ).Error("gateway.record_usage_failed", zap.Error(err)) } }) - reqLog.Debug("gateway.request_completed", - zap.Int64("account_id", account.ID), - zap.Int("switch_count", fs.SwitchCount), - zap.Bool("fallback_used", fallbackUsed), - ) return } if !retryWithFallback { diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 83b65159..25100c82 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -916,8 +916,8 @@

{{ t('admin.accounts.gemini.tier.aiStudioHint') }}

- -
+ +
- -
-
-
- - - -
-

- {{ t('admin.accounts.gemini.modelPassthrough') }} -

-

- {{ t('admin.accounts.gemini.modelPassthroughDesc') }} -

-
-
-
-
@@ -1378,9 +1350,9 @@
- +
@@ -2562,8 +2534,8 @@ watch( antigravityModelMappings.value = [] antigravityModelRestrictionMode.value = 'mapping' } - // Reset Anthropic-specific settings when switching to other platforms - if (newPlatform !== 'anthropic') { + // Reset Anthropic/Antigravity-specific settings when switching to other platforms + if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') { interceptWarmupRequests.value = false } if (newPlatform === 'sora') { @@ -3117,15 +3089,8 @@ const handleSubmit = async () => { applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') - submitting.value = true - try { - const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined - await createAccountAndFinish(form.platform, 'apikey', credentials, extra) - } catch (error: any) { - appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) - } finally { - submitting.value = false - } + const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined + await createAccountAndFinish(form.platform, 'apikey', credentials, extra) return } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index c29aa54b..32508132 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -65,8 +65,8 @@

{{ t('admin.accounts.leaveEmptyToKeep') }}

- -
+ +
- -
-
-
- - - -
-

- {{ t('admin.accounts.gemini.modelPassthrough') }} -

-

- {{ t('admin.accounts.gemini.modelPassthroughDesc') }} -

-
-
-
-
@@ -641,9 +613,9 @@
- +
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 70fe5a27..a54cfcef 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -581,6 +581,7 @@ export interface GeminiCredentials { token_type?: string scope?: string expires_at?: string + model_mapping?: Record } export interface TempUnschedulableRule { From a3ff317f1ccfb595f7f94fa99346f032c54a07a8 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 22:11:50 +0800 Subject: [PATCH 21/24] feat: optimize model rate limit indicator layout with short aliases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Change layout from fixed 3-column grid to vertical-first responsive columns (1 col for ≤4 items, 2 cols for ≤8, 3 cols for 9+) - Add short aliases for all known model scope keys (e.g. COpus46, CSon46, G3PH, G3F) to reduce badge width - Display countdown timer directly on each badge (supports h/m/s) - Retain legacy scope aliases for backward compatibility --- .../account/AccountStatusIndicator.vue | 63 ++++++++++++++++--- 1 file changed, 53 insertions(+), 10 deletions(-) diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index af32ea0c..8816eb26 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -77,13 +77,23 @@
-
-
+
+
{{ formatScopeName(item.model) }} + {{ formatModelResetTime(item.reset_at) }}
{ }) const formatScopeName = (scope: string): string => { - const names: Record = { + const aliases: Record = { + // Claude 系列 + 'claude-opus-4-6-thinking': 'COpus46', + 'claude-sonnet-4-6': 'CSon46', + 'claude-sonnet-4-5': 'CSon45', + 'claude-sonnet-4-5-thinking': 'CSon45T', + // Gemini 2.5 系列 + 'gemini-2.5-flash': 'G25F', + 'gemini-2.5-flash-lite': 'G25FL', + 'gemini-2.5-flash-thinking': 'G25FT', + 'gemini-2.5-pro': 'G25P', + // Gemini 3 系列 + 'gemini-3-flash': 'G3F', + 'gemini-3.1-pro-high': 'G3PH', + 'gemini-3.1-pro-low': 'G3PL', + 'gemini-3-pro-image': 'G3PI', + // 其他 + 'gpt-oss-120b-medium': 'GPT120', + 'tab_flash_lite_preview': 'TabFL', + // 旧版 scope 别名(兼容) claude: 'Claude', - claude_sonnet: 'Claude Sonnet', - claude_opus: 'Claude Opus', - claude_haiku: 'Claude Haiku', + claude_sonnet: 'CSon', + claude_opus: 'COpus', + claude_haiku: 'CHaiku', gemini_text: 'Gemini', - gemini_image: 'Image', - gemini_flash: 'Gemini Flash', - gemini_pro: 'Gemini Pro' + gemini_image: 'GImg', + gemini_flash: 'GFlash', + gemini_pro: 'GPro', } - return names[scope] || scope + return aliases[scope] || scope +} + +const formatModelResetTime = (resetAt: string): string => { + const date = new Date(resetAt) + const now = new Date() + const diffMs = date.getTime() - now.getTime() + if (diffMs <= 0) return '' + const totalSecs = Math.floor(diffMs / 1000) + const h = Math.floor(totalSecs / 3600) + const m = Math.floor((totalSecs % 3600) / 60) + const s = totalSecs % 60 + if (h > 0) return `${h}h${m}m` + if (m > 0) return `${m}m${s}s` + return `${s}s` } // Computed: is overloaded (529) From c671e8dd1d4e6c6af786f8b301a3d89da539d85b Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 23:24:48 +0800 Subject: [PATCH 22/24] =?UTF-8?q?fix:=20=E7=BB=9F=E4=B8=80gemini-3?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E6=98=A0=E5=B0=84=E4=B8=BA=E9=9D=9E=E5=BC=BA?= =?UTF-8?q?=E5=88=B63.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/domain/constants.go | 6 +++--- backend/internal/service/model_rate_limit_test.go | 4 ++-- backend/migrations/058_add_sonnet46_to_model_mapping.sql | 6 +++--- frontend/src/composables/useModelWhitelist.ts | 8 ++++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index e674facb..c41aa65f 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -90,12 +90,12 @@ var DefaultAntigravityModelMapping = map[string]string{ "gemini-2.5-pro": "gemini-2.5-pro", // Gemini 3 白名单 "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-high": "gemini-3.1-pro-high", - "gemini-3-pro-low": "gemini-3.1-pro-low", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", "gemini-3-pro-image": "gemini-3-pro-image", // Gemini 3 preview 映射 "gemini-3-flash-preview": "gemini-3-flash", - "gemini-3-pro-preview": "gemini-3.1-pro-high", + "gemini-3-pro-preview": "gemini-3-pro-high", "gemini-3-pro-image-preview": "gemini-3-pro-image", // Gemini 3.1 白名单 "gemini-3.1-pro-high": "gemini-3.1-pro-high", diff --git a/backend/internal/service/model_rate_limit_test.go b/backend/internal/service/model_rate_limit_test.go index c8f78ce3..b79b9688 100644 --- a/backend/internal/service/model_rate_limit_test.go +++ b/backend/internal/service/model_rate_limit_test.go @@ -107,12 +107,12 @@ func TestIsModelRateLimited(t *testing.T) { expected: true, }, { - name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3.1-pro-high", + name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3-pro-high", account: &Account{ Platform: PlatformAntigravity, Extra: map[string]any{ modelRateLimitsKey: map[string]any{ - "gemini-3.1-pro-high": map[string]any{ + "gemini-3-pro-high": map[string]any{ "rate_limit_reset_at": future, }, }, diff --git a/backend/migrations/058_add_sonnet46_to_model_mapping.sql b/backend/migrations/058_add_sonnet46_to_model_mapping.sql index 93e7b39d..aa7657d7 100644 --- a/backend/migrations/058_add_sonnet46_to_model_mapping.sql +++ b/backend/migrations/058_add_sonnet46_to_model_mapping.sql @@ -27,11 +27,11 @@ SET credentials = jsonb_set( "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", "gemini-2.5-pro": "gemini-2.5-pro", "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-high": "gemini-3.1-pro-high", - "gemini-3-pro-low": "gemini-3.1-pro-low", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", "gemini-3-pro-image": "gemini-3-pro-image", "gemini-3-flash-preview": "gemini-3-flash", - "gemini-3-pro-preview": "gemini-3.1-pro-high", + "gemini-3-pro-preview": "gemini-3-pro-high", "gemini-3-pro-image-preview": "gemini-3-pro-image", "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview" diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index fc7bdc03..7779bf26 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -291,10 +291,10 @@ const antigravityPresetMappings = [ { label: 'Sonnet→Sonnet', from: 'claude-sonnet-*', to: 'claude-sonnet-4-5', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, { label: 'Opus→Opus', from: 'claude-opus-*', to: 'claude-opus-4-6-thinking', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Haiku→Sonnet', from: 'claude-haiku-*', to: 'claude-sonnet-4-5', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' }, - // Gemini 3→3.1 映射 - { label: '3-Pro-Preview→3.1-Pro-High', from: 'gemini-3-pro-preview', to: 'gemini-3.1-pro-high', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }, - { label: '3-Pro-High→3.1-Pro-High', from: 'gemini-3-pro-high', to: 'gemini-3.1-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' }, - { label: '3-Pro-Low→3.1-Pro-Low', from: 'gemini-3-pro-low', to: 'gemini-3.1-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' }, + // Gemini 3 映射 + { label: '3-Pro-Preview→3-Pro-High', from: 'gemini-3-pro-preview', to: 'gemini-3-pro-high', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }, + { label: '3-Pro-High', from: 'gemini-3-pro-high', to: 'gemini-3-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' }, + { label: '3-Pro-Low', from: 'gemini-3-pro-low', to: 'gemini-3-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' }, // Gemini 通配符映射 { label: 'Gemini 3→Flash', from: 'gemini-3*', to: 'gemini-3-flash', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' }, { label: 'Gemini 2.5→Flash', from: 'gemini-2.5*', to: 'gemini-2.5-flash', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' }, From 5bd7408b2f27778654b9a73a9904bc5aa8c8681a Mon Sep 17 00:00:00 2001 From: erio Date: Wed, 25 Feb 2026 00:10:07 +0800 Subject: [PATCH 23/24] fix: add fallback pricing for opus-4.6 and gemini-3.1-pro models --- backend/internal/service/billing_service.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index f100be0b..af29d614 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -133,6 +133,18 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok SupportsCacheBreakdown: false, } + + // Claude 4.6 Opus (与4.5同价) + s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"] + + // Gemini 3.1 Pro + s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{ + InputPricePerToken: 2e-6, // $2 per MTok + OutputPricePerToken: 12e-6, // $12 per MTok + CacheCreationPricePerToken: 2e-6, // $2 per MTok + CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok + SupportsCacheBreakdown: false, + } } // getFallbackPricing 根据模型系列获取回退价格 @@ -141,6 +153,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { // 按模型系列匹配 if strings.Contains(modelLower, "opus") { + if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") { + return s.fallbackPrices["claude-opus-4.6"] + } if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") { return s.fallbackPrices["claude-opus-4.5"] } @@ -158,6 +173,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { } return s.fallbackPrices["claude-3-haiku"] } + if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") { + return s.fallbackPrices["gemini-3.1-pro"] + } // 默认使用Sonnet价格 return s.fallbackPrices["claude-sonnet-4"] From 58f21e4b3a633088e83bb03726499dfdc87596ac Mon Sep 17 00:00:00 2001 From: erio Date: Wed, 25 Feb 2026 00:23:37 +0800 Subject: [PATCH 24/24] fix: correct gofmt alignment in gemini-3.1-pro fallback pricing --- backend/internal/service/billing_service.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index af29d614..a523001c 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -139,9 +139,9 @@ func (s *BillingService) initFallbackPricing() { // Gemini 3.1 Pro s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{ - InputPricePerToken: 2e-6, // $2 per MTok - OutputPricePerToken: 12e-6, // $12 per MTok - CacheCreationPricePerToken: 2e-6, // $2 per MTok + InputPricePerToken: 2e-6, // $2 per MTok + OutputPricePerToken: 12e-6, // $12 per MTok + CacheCreationPricePerToken: 2e-6, // $2 per MTok CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok SupportsCacheBreakdown: false, }