Compare commits

..

41 Commits

Author SHA1 Message Date
shaw
ded9b6c14e fix: upgrade utls to v1.8.2 to resolve GO-2026-4512 vulnerability 2026-02-25 08:57:43 +08:00
Wesley Liddick
609abbbd7c Merge pull request #624 from cagedbird043/pr/antigravity-gemini31-passthrough-buttons
feat: 补充 Antigravity 的 Gemini 3.1 Pro 透传快捷按钮
2026-02-25 08:45:49 +08:00
Wesley Liddick
1b4e504fad Merge pull request #625 from cagedbird043/pr/antigravity-default-gemini31-passthrough
fix: 默认补全 Antigravity 的 Gemini 3.1 Pro 透传映射
2026-02-25 08:45:16 +08:00
Wesley Liddick
0a3a445828 Merge pull request #628 from cagedbird043/pr/docs-model-mapping-bulk-edit-tip
docs: 增加跨平台批量修改导致模型映射丢失的排障经验
2026-02-25 08:31:31 +08:00
Wesley Liddick
c7e18bd5be Merge pull request #627 from touwaeriol/pr/bugfixes-and-enhancements
feat: 反重力(Antigravity)增强、Failover 重构及新模型支持
2026-02-25 08:30:25 +08:00
cagedbird043
083d202fe4 docs: 增加跨平台批量修改导致模型映射丢失的排障经验 2026-02-25 01:02:25 +08:00
erio
8365a8328b merge: resolve conflicts with upstream/main (Gemini 3→3.1 mappings) 2026-02-25 00:38:39 +08:00
erio
58f21e4b3a fix: correct gofmt alignment in gemini-3.1-pro fallback pricing 2026-02-25 00:23:37 +08:00
erio
5bd7408b2f fix: add fallback pricing for opus-4.6 and gemini-3.1-pro models 2026-02-25 00:10:07 +08:00
erio
c671e8dd1d fix: 统一gemini-3默认映射为非强制3.1 2026-02-24 23:24:48 +08:00
cagedbird043
a3aed3c4c3 fix: 默认补全 antigravity 的 Gemini 3.1 Pro 透传映射 2026-02-24 22:54:11 +08:00
cagedbird043
c008649584 feat: 补充 antigravity 的 Gemini 3.1 Pro 透传快捷按钮 2026-02-24 22:53:53 +08:00
Wesley Liddick
516f8f287c Merge pull request #623 from cagedbird043/fix/antigravity-mapping-upgrade-additions
fix: 补全 Antigravity 模型映射升级与快捷按钮
2026-02-24 22:50:24 +08:00
Wesley Liddick
66148690c6 Merge pull request #622 from cagedbird043/fix/auto-clear-account-error-on-usage
fix: 刷新用量成功后自动清理账号可恢复错误状态
2026-02-24 22:49:08 +08:00
Wesley Liddick
cadd7f546f Merge pull request #621 from cagedbird043/fix/gemini-auth-url-613
fix: 修复 Gemini 授权链接生成失败(issue #613)
2026-02-24 22:48:09 +08:00
erio
a3ff317f1c feat: optimize model rate limit indicator layout with short aliases
- Change layout from fixed 3-column grid to vertical-first responsive
  columns (1 col for ≤4 items, 2 cols for ≤8, 3 cols for 9+)
- Add short aliases for all known model scope keys (e.g. COpus46, CSon46,
  G3PH, G3F) to reduce badge width
- Display countdown timer directly on each badge (supports h/m/s)
- Retain legacy scope aliases for backward compatibility
2026-02-24 22:11:50 +08:00
erio
d8d4b0c0c7 fix: enable Gemini model_mapping UI and extend warmup to Antigravity
- Remove Gemini platform exclusion from model restriction UI in
  Create/Edit account modals (Gemini now supports model_mapping)
- Remove outdated Gemini model passthrough info cards
- Add model_mapping field to GeminiCredentials type
- Extend warmup request interception toggle to Antigravity platform
- Remove redundant try/catch in API key account creation
- Remove noisy gateway.request_completed debug log
- Reorganize Gemini model mapping sections in constants.go
2026-02-24 21:30:32 +08:00
erio
d616f8c854 refactor: remove unused ClientSecret constant
The ClientSecret constant was left as an empty string after
getClientSecret() was refactored to use defaultClientSecret.
Remove the dead constant and update the test accordingly.
2026-02-24 21:09:46 +08:00
erio
b6fa8b8eec fix: update tests for defaultClientSecret and align migration 058
- Fix oauth_test.go and client_test.go to use defaultClientSecret
  variable instead of env var (init() already sets the default)
- Align migration 058 gemini-3-pro-high/low/preview mappings with
  constants.go (map to 3.1 versions)
2026-02-24 21:06:10 +08:00
erio
36d2e6999b feat: add default value for Antigravity OAuth client secret
Add a built-in default for ANTIGRAVITY_OAUTH_CLIENT_SECRET so the
service works out of the box without requiring environment variable
configuration. The env var can still override the default.
2026-02-24 20:54:28 +08:00
cagedbird043
076c00063d feat: 补全 antigravity 模型映射快捷按钮 2026-02-24 20:31:36 +08:00
cagedbird043
ea8104c6a2 fix: antigravity 默认补全 gemini-3-flash 透传 2026-02-24 20:31:36 +08:00
erio
ca3e9336e1 test: update UserAgent version assertion to match 1.18.4 default 2026-02-24 20:31:02 +08:00
erio
f92ab48166 fix: add gemini-3.1-pro-preview to default Antigravity model mapping
Add missing gemini-3.1-pro-preview -> gemini-3.1-pro-high mapping to
DefaultAntigravityModelMapping for consistency with migration 059.
2026-02-24 20:06:19 +08:00
cagedbird043
c10267ce2b fix: 刷新用量成功后自动清理账号可恢复错误状态 2026-02-24 20:04:36 +08:00
cagedbird043
9bd6a62ab3 test: 更新 Gemini OAuth 内置回退测试用例 2026-02-24 20:04:05 +08:00
cagedbird043
0dbea6ca58 fix: 修复 Gemini 授权链接生成失败并改进错误提示 2026-02-24 20:04:05 +08:00
erio
6523b23221 revert: remove backend-ci.yml changes (fork-specific CI config) 2026-02-24 19:45:23 +08:00
erio
29c406dda0 feat: add migrations for sonnet-4-6 and gemini-3.1-pro model mappings
Add migration 058 to update existing Antigravity accounts with
claude-sonnet-4-6 in model_mapping. Add migration 059 to add
gemini-3.1-pro-high/low/preview mappings.
2026-02-24 19:40:30 +08:00
erio
483c8f246d chore: update default Antigravity UserAgent version to 1.18.4
Update the default ANTIGRAVITY_USER_AGENT_VERSION from 1.84.2 to
1.18.4 to match the current Antigravity-Manager desktop client.
2026-02-24 19:39:15 +08:00
erio
645f283108 feat: add claude-sonnet-4-6 and gemini-3.1-pro model support
Add claude-sonnet-4-6 to identity injection modelInfoMap and
Antigravity model selector. Add gemini-3.1-pro-high/low to
Antigravity model list and Sonnet 4.6 preset mapping.
2026-02-24 19:30:01 +08:00
erio
da6fd45000 chore: add sonnet-4-6 mapping, config defaults, and CI improvements
- Add claude-sonnet-4-6 to default Antigravity model mapping
- Add antigravity_extra_retries default value in config
- Add cache-dependency-path to CI setup-go for faster builds
- Simplify vitest config to avoid vite plugin compatibility issues
2026-02-24 18:55:39 +08:00
erio
fb3ef5f388 fix(frontend): add Gemini models to bulk edit and fix status grid layout
Add Gemini model presets to BulkEditAccountModal for bulk model mapping.
Fix AccountStatusIndicator model rate limit grid layout using proper
grid container.
2026-02-24 18:55:25 +08:00
erio
86bc76e352 test: add warmup request interception unit tests
Add comprehensive tests for warmup request interception behavior
covering Antigravity accounts with various credential configurations.
2026-02-24 18:55:11 +08:00
erio
644058174e fix(gemini): enable model_mapping filtering for Gemini API Key accounts
Remove the special case that bypassed model-supported checks for Gemini
API Key accounts, allowing model_mapping to filter requests properly.
Add tests for multiplatform model filtering behavior.
2026-02-24 18:54:59 +08:00
erio
4573868c08 fix(antigravity): bill with mapped model and use final model key for rate limiting
- Use mapped model (billingModel) instead of original request model for billing
- Use resolveFinalAntigravityModelKey for 429 rate limit model key,
  ensuring rate limit records match the actual upstream model
- Add regression tests for both fixes
2026-02-24 18:08:19 +08:00
erio
09166a52f8 refactor: extract failover error handling into FailoverState
- Extract duplicated failover logic from gateway_handler.go (3 places)
  and gemini_v1beta_handler.go into shared failover_loop.go
- Introduce FailoverState with HandleFailoverError and HandleSelectionExhausted
- Move helper functions (needForceCacheBilling, sleepWithContext) into failover_loop.go
- Add comprehensive unit tests (32+ test cases)
- Delete redundant gateway_handler_single_account_retry_test.go
2026-02-24 18:08:04 +08:00
erio
aaac1aaca9 feat: add mixed-channel precheck API for account-group binding
Add a dedicated CheckMixedChannel endpoint that allows the frontend
to pre-validate mixed channel risk before submitting create/update
requests. This improves UX by showing warnings earlier in the flow
instead of only after form submission.

Backend changes:
- Add CheckMixedChannelRequest struct and CheckMixedChannel handler
- Register POST /check-mixed-channel route
- Expose CheckMixedChannelRisk as public method on AdminService
- Simplify Create/Update 409 responses (remove details/require_confirmation)
- Add comprehensive handler tests and stub methods

Frontend changes:
- Add checkMixedChannelRisk API function and TypeScript types
- Refactor CreateAccountModal to precheck before step transition and submission
- Refactor EditAccountModal to precheck before update submission
- Replace pendingPayload pattern with action-based dialog flow
2026-02-24 17:16:53 +08:00
erio
59898c16c6 fix: fix intercept_warmup_requests config not being saved
Extract applyInterceptWarmup utility to unify all credential building
call sites:
- Fix upstream account creation missing intercept_warmup_requests write
- Fix apikey edit mode missing else-branch to clear the setting
- Add backend unit test for IsInterceptWarmupEnabled
- Add frontend unit test for credentialsBuilder
2026-02-24 16:48:16 +08:00
erio
0dacdf480b fix: distinguish client disconnection from upstream retry failure
Before this change, when a client disconnected mid-request, the error
message was "Upstream request failed after retries", which is misleading
and pollutes error logs. Now we check context.Err() to return a more
accurate "Client disconnected" message for both Claude and Gemini
forward paths.
2026-02-24 16:45:08 +08:00
erio
fdf9f68298 fix: update Claude usage window to support 4.6 models
The usage progress bar only matched claude-sonnet-4-5 and
claude-opus-4-5-thinking. After upgrading to 4.6, the backend returns
claude-sonnet-4-6/claude-opus-4-6-thinking which didn't match,
causing the Claude usage bar to not display.

- Add claude-sonnet-4-6 and claude-opus-4-6-thinking to the match list
- Rename label from "C4.5" to "Claude" for future-proofing
2026-02-24 16:44:18 +08:00
50 changed files with 2816 additions and 685 deletions

View File

@@ -209,7 +209,30 @@ git add ent/ # 生成的文件也要提交
---
### 坑 10PR 提交前检查清单
### 坑 10前端测试看似正常,但后端调用失败(模型映射被批量误改)
**典型现象**
- 前端按钮点测看起来正常;
- 实际通过 API/客户端调用时返回 `Service temporarily unavailable` 或提示无可用账号;
- 常见于 OpenAI 账号(例如 Codex 模型)在批量修改后突然不可用。
**根因**
- OpenAI 账号编辑页默认不显式展示映射规则,容易让人误以为“没映射也没关系”;
- 但在**批量修改同时选中不同平台账号**OpenAI + Antigravity/Gemini模型白名单/映射可能被跨平台策略覆盖;
- 结果是 OpenAI 账号的关键模型映射丢失或被改坏,后端选不到可用账号。
**修复方案(按优先级)**
1. **快速修复(推荐)**:在批量修改中补回正确的透传映射(例如 `gpt-5.3-codex -> gpt-5.3-codex-spark`)。
2. **彻底重建**:删除并重新添加全部相关账号(最稳但成本高)。
**关键经验**
- 如果某模型已被软件内置默认映射覆盖,通常不需要额外再加透传;
- 但当上游模型更新快于本仓库默认映射时,**手动批量添加透传映射**是最简单、最低风险的临时兜底方案;
- 批量操作前尽量按平台分组,不要混选不同平台账号。
---
### 坑 11PR 提交前检查清单
提交 PR 前务必本地验证:

View File

@@ -18,7 +18,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pquerna/otp v1.5.0
github.com/redis/go-redis/v9 v9.17.2
github.com/refraction-networking/utls v1.8.1
github.com/refraction-networking/utls v1.8.2
github.com/robfig/cron/v3 v3.0.1
github.com/shirou/gopsutil/v4 v4.25.6
github.com/spf13/viper v1.18.2
@@ -79,7 +79,6 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@@ -148,7 +147,6 @@ require (
golang.org/x/mod v0.31.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
golang.org/x/tools v0.40.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect

View File

@@ -120,8 +120,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -176,8 +174,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -211,8 +207,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -238,12 +232,10 @@ github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI1
github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -266,8 +258,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@@ -1158,6 +1158,7 @@ func setDefaults() {
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))

View File

@@ -74,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
// Claude 详细版本 ID 映射
@@ -89,16 +90,18 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-2.5-pro": "gemini-2.5-pro",
// Gemini 3 白名单
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3.1-pro-high",
"gemini-3-pro-low": "gemini-3.1-pro-low",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
// Gemini 3.1 透传
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
// Gemini 3 preview 映射
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3.1-pro-high",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
// Gemini 3.1 白名单
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
// Gemini 3.1 preview 映射
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
// 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",

View File

@@ -139,6 +139,13 @@ type BulkUpdateAccountsRequest struct {
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// CheckMixedChannelRequest represents check mixed channel risk request
type CheckMixedChannelRequest struct {
Platform string `json:"platform" binding:"required"`
GroupIDs []int64 `json:"group_ids"`
AccountID *int64 `json:"account_id"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*dto.Account
@@ -389,6 +396,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
// POST /api/v1/admin/accounts/check-mixed-channel
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
var req CheckMixedChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.GroupIDs) == 0 {
response.Success(c, gin.H{"has_risk": false})
return
}
accountID := int64(0)
if req.AccountID != nil {
accountID = *req.AccountID
}
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
response.Success(c, gin.H{
"has_risk": true,
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"has_risk": false})
}
// Create handles creating a new account
// POST /api/v1/admin/accounts
func (h *AccountHandler) Create(c *gin.Context) {
@@ -431,17 +482,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
@@ -501,17 +545,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}

View File

@@ -0,0 +1,147 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
return router
}
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["has_risk"])
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
}
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.checkMixedErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
"account_id": 99,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["has_risk"])
require.Equal(t, "mixed_channel_warning", data["error"])
details, ok := data["details"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(27), details["group_id"])
require.Equal(t, "claude-max", details["group_name"])
require.Equal(t, "Antigravity", details["current_platform"])
require.Equal(t, "Anthropic", details["other_platform"])
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
}
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.createAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"name": "ag-oauth-1",
"platform": "antigravity",
"type": "oauth",
"credentials": map[string]any{"refresh_token": "rt"},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.updateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}

View File

@@ -10,19 +10,27 @@ import (
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
mu sync.Mutex
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
createAccountErr error
updateAccountErr error
checkMixedErr error
lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
}
mu sync.Mutex
}
func newStubAdminService() *stubAdminService {
@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock()
if s.createAccountErr != nil {
return nil, s.createAccountErr
}
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
if s.updateAccountErr != nil {
return nil, s.updateAccountErr
}
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
s.lastMixedCheck.accountID = currentAccountID
s.lastMixedCheck.platform = currentAccountPlatform
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
return s.checkMixedErr
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies))

View File

@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
if strings.Contains(msg, "OAuth client not configured") ||
strings.Contains(msg, "requires your own OAuth Client") ||
strings.Contains(msg, "requires a custom OAuth Client") ||
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}

View File

@@ -0,0 +1,160 @@
package handler
import (
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
// GatewayService 隐式实现此接口。
type TempUnscheduler interface {
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
}
// FailoverAction 表示 failover 错误处理后的下一步动作
type FailoverAction int
const (
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue
FailoverContinue FailoverAction = iota
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
FailoverExhausted
// FailoverCanceled context 已取消(调用方应直接 return
FailoverCanceled
)
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s
// Handler 层只需短暂间隔后重新进入 Service 层即可。
singleAccountBackoffDelay = 2 * time.Second
)
// FailoverState 跨循环迭代共享的 failover 状态
type FailoverState struct {
SwitchCount int
MaxSwitches int
FailedAccountIDs map[int64]struct{}
SameAccountRetryCount map[int64]int
LastFailoverErr *service.UpstreamFailoverError
ForceCacheBilling bool
hasBoundSession bool
}
// NewFailoverState 创建 failover 状态
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
return &FailoverState{
MaxSwitches: maxSwitches,
FailedAccountIDs: make(map[int64]struct{}),
SameAccountRetryCount: make(map[int64]int),
hasBoundSession: hasBoundSession,
}
}
// HandleFailoverError 处理 UpstreamFailoverError返回下一步动作。
// 包含缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
func (s *FailoverState) HandleFailoverError(
ctx context.Context,
gatewayService TempUnscheduler,
accountID int64,
platform string,
failoverErr *service.UpstreamFailoverError,
) FailoverAction {
s.LastFailoverErr = failoverErr
// 缓存计费判断
if needForceCacheBilling(s.hasBoundSession, failoverErr) {
s.ForceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
s.SameAccountRetryCount[accountID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
if !sleepWithContext(ctx, sameAccountRetryDelay) {
return FailoverCanceled
}
return FailoverContinue
}
// 同账号重试用尽,执行临时封禁
if failoverErr.RetryableOnSameAccount {
gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
}
// 加入失败列表
s.FailedAccountIDs[accountID] = struct{}{}
// 检查是否耗尽
if s.SwitchCount >= s.MaxSwitches {
return FailoverExhausted
}
// 递增切换计数
s.SwitchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d",
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
// Antigravity 平台换号线性递增延时
if platform == service.PlatformAntigravity {
delay := time.Duration(s.SwitchCount-1) * time.Second
if !sleepWithContext(ctx, delay) {
return FailoverCanceled
}
}
return FailoverContinue
}
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
// 清除排除列表、等待退避后重新选号。
//
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
// 返回 FailoverExhausted 时,调用方应返回错误响应。
// 返回 FailoverCanceled 时,调用方应直接 return。
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
if s.LastFailoverErr != nil &&
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
s.SwitchCount <= s.MaxSwitches {
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
singleAccountBackoffDelay, s.SwitchCount)
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
return FailoverCanceled
}
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
s.SwitchCount, s.MaxSwitches)
s.FailedAccountIDs = make(map[int64]struct{})
return FailoverContinue
}
return FailoverExhausted
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
func sleepWithContext(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(d):
return true
}
}

View File

@@ -0,0 +1,732 @@
package handler
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mock
// ---------------------------------------------------------------------------
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
type mockTempUnscheduler struct {
calls []tempUnscheduleCall
}
type tempUnscheduleCall struct {
accountID int64
failoverErr *service.UpstreamFailoverError
}
func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) {
m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr})
}
// ---------------------------------------------------------------------------
// Helper
// ---------------------------------------------------------------------------
func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError {
return &service.UpstreamFailoverError{
StatusCode: statusCode,
RetryableOnSameAccount: retryable,
ForceCacheBilling: forceBilling,
}
}
// ---------------------------------------------------------------------------
// NewFailoverState 测试
// ---------------------------------------------------------------------------
func TestNewFailoverState(t *testing.T) {
t.Run("初始化字段正确", func(t *testing.T) {
fs := NewFailoverState(5, true)
require.Equal(t, 5, fs.MaxSwitches)
require.Equal(t, 0, fs.SwitchCount)
require.NotNil(t, fs.FailedAccountIDs)
require.Empty(t, fs.FailedAccountIDs)
require.NotNil(t, fs.SameAccountRetryCount)
require.Empty(t, fs.SameAccountRetryCount)
require.Nil(t, fs.LastFailoverErr)
require.False(t, fs.ForceCacheBilling)
require.True(t, fs.hasBoundSession)
})
t.Run("无绑定会话", func(t *testing.T) {
fs := NewFailoverState(3, false)
require.Equal(t, 3, fs.MaxSwitches)
require.False(t, fs.hasBoundSession)
})
t.Run("零最大切换次数", func(t *testing.T) {
fs := NewFailoverState(0, false)
require.Equal(t, 0, fs.MaxSwitches)
})
}
// ---------------------------------------------------------------------------
// sleepWithContext 测试
// ---------------------------------------------------------------------------
func TestSleepWithContext(t *testing.T) {
t.Run("零时长立即返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), 0)
require.True(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("负时长立即返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), -1*time.Second)
require.True(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("正常等待后返回true", func(t *testing.T) {
start := time.Now()
ok := sleepWithContext(context.Background(), 50*time.Millisecond)
elapsed := time.Since(start)
require.True(t, ok)
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond)
require.Less(t, elapsed, 500*time.Millisecond)
})
t.Run("已取消context立即返回false", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
ok := sleepWithContext(ctx, 5*time.Second)
require.False(t, ok)
require.Less(t, time.Since(start), 50*time.Millisecond)
})
t.Run("等待期间context取消返回false", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(30 * time.Millisecond)
cancel()
}()
start := time.Now()
ok := sleepWithContext(ctx, 5*time.Second)
elapsed := time.Since(start)
require.False(t, ok)
require.Less(t, elapsed, 500*time.Millisecond)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 基本切换流程
// ---------------------------------------------------------------------------
func TestHandleFailoverError_BasicSwitch(t *testing.T) {
t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
require.Equal(t, err, fs.LastFailoverErr)
require.False(t, fs.ForceCacheBilling)
require.Empty(t, mock.calls, "不应调用 TempUnschedule")
})
t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) {
// switchCount 从 0→1 时sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0")
})
t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) {
// switchCount 从 1→2 时sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1 // 模拟已切换一次
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s")
require.Less(t, elapsed, 3*time.Second)
})
t.Run("连续切换直到耗尽", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(2, false)
// 第一次切换0→1
err1 := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
// 第二次切换1→2
err2 := newTestFailoverErr(502, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
// 第三次已耗尽SwitchCount(2) >= MaxSwitches(2)
err3 := newTestFailoverErr(503, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
require.Equal(t, FailoverExhausted, action)
require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增")
// 验证失败账号列表
require.Len(t, fs.FailedAccountIDs, 3)
require.Contains(t, fs.FailedAccountIDs, int64(100))
require.Contains(t, fs.FailedAccountIDs, int64(200))
require.Contains(t, fs.FailedAccountIDs, int64(300))
// LastFailoverErr 应为最后一次的错误
require.Equal(t, err3, fs.LastFailoverErr)
})
t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(0, false)
err := newTestFailoverErr(500, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverExhausted, action)
require.Equal(t, 0, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
// ---------------------------------------------------------------------------
func TestHandleFailoverError_CacheBilling(t *testing.T) {
t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, true) // hasBoundSession=true
err := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.True(t, fs.ForceCacheBilling)
})
t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.True(t, fs.ForceCacheBilling)
})
t.Run("两者均为false时不设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.False(t, fs.ForceCacheBilling)
})
t.Run("一旦设置不会被后续错误重置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
// 第一次ForceCacheBilling=true → 设置
err1 := newTestFailoverErr(500, false, true)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.True(t, fs.ForceCacheBilling)
// 第二次ForceCacheBilling=false → 仍然保持 true
err2 := newTestFailoverErr(502, false, false)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
// ---------------------------------------------------------------------------
func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
t.Run("第一次重试返回FailoverContinue", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数")
require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表")
require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule")
// 验证等待了 sameAccountRetryDelay (500ms)
require.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
require.Less(t, elapsed, 2*time.Second)
})
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
// 第二次
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
})
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次、第二次重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Contains(t, fs.FailedAccountIDs, int64(100))
// 验证 TempUnschedule 被调用
require.Len(t, mock.calls, 1)
require.Equal(t, int64(100), mock.calls[0].accountID)
require.Equal(t, err, mock.calls[0].failoverErr)
})
t.Run("不同账号独立跟踪重试次数", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
err := newTestFailoverErr(400, true, false)
// 账号 100 第一次重试
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
// 账号 200 第一次重试(独立计数)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[200])
require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响")
})
t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
err := newTestFailoverErr(400, true, false)
// 耗尽账号 100 的重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
// 第三次: 重试耗尽 → 切换
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
// 再次遇到账号 100计数仍为 2条件不满足 → 直接切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — TempUnschedule 调用验证
// ---------------------------------------------------------------------------
func TestHandleFailoverError_TempUnschedule(t *testing.T) {
t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Empty(t, mock.calls)
})
t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(502, true, false)
// 耗尽重试
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
require.Len(t, mock.calls, 1)
require.Equal(t, int64(42), mock.calls[0].accountID)
require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode)
require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — Context 取消
// ---------------------------------------------------------------------------
func TestHandleFailoverError_ContextCanceled(t *testing.T) {
t.Run("同账号重试sleep期间context取消", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
action := fs.HandleFailoverError(ctx, mock, 100, "openai", err)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
// 重试计数仍应递增
require.Equal(t, 1, fs.SameAccountRetryCount[100])
})
t.Run("Antigravity延迟期间context取消", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s
err := newTestFailoverErr(500, false, false)
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — FailedAccountIDs 跟踪
// ---------------------------------------------------------------------------
func TestHandleFailoverError_FailedAccountIDs(t *testing.T) {
t.Run("切换时添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Contains(t, fs.FailedAccountIDs, int64(100))
fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false))
require.Contains(t, fs.FailedAccountIDs, int64(200))
require.Len(t, fs.FailedAccountIDs, 2)
})
t.Run("耗尽时也添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(0, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Equal(t, FailoverExhausted, action)
require.Contains(t, fs.FailedAccountIDs, int64(100))
})
t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false))
require.Equal(t, FailoverContinue, action)
require.NotContains(t, fs.FailedAccountIDs, int64(100))
})
t.Run("同一账号多次切换不重复添加", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(5, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — LastFailoverErr 更新
// ---------------------------------------------------------------------------
func TestHandleFailoverError_LastFailoverErr(t *testing.T) {
t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err1 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.Equal(t, err1, fs.LastFailoverErr)
err2 := newTestFailoverErr(502, false, false)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.Equal(t, err2, fs.LastFailoverErr)
})
t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, err, fs.LastFailoverErr)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 综合集成场景
// ---------------------------------------------------------------------------
func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, true) // hasBoundSession=true
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
retryErr := newTestFailoverErr(400, true, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Len(t, mock.calls, 1)
// 3. 账号 200 遇到不可重试错误 → 直接切换
switchErr := newTestFailoverErr(500, false, false)
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SwitchCount)
// 4. 账号 300 遇到不可重试错误 → 再切换
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 3, fs.SwitchCount)
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr)
require.Equal(t, FailoverExhausted, action)
// 最终状态验证
require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增")
require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中")
require.True(t, fs.ForceCacheBilling)
require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule")
})
t.Run("模拟Antigravity平台完整流程", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(2, false)
err := newTestFailoverErr(500, false, false)
// 第一次切换delay = 0s
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0")
// 第二次切换delay = 1s
start = time.Now()
action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
elapsed = time.Since(start)
require.Equal(t, FailoverContinue, action)
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s")
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
start = time.Now()
action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err)
elapsed = time.Since(start)
require.Equal(t, FailoverExhausted, action)
require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟")
})
t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false) // hasBoundSession=false
// 第一次ForceCacheBilling=false
err1 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
require.False(t, fs.ForceCacheBilling)
// 第二次ForceCacheBilling=trueAntigravity 粘性会话切换)
err2 := newTestFailoverErr(500, false, true)
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling")
// 第三次ForceCacheBilling=false但状态仍保持 true
err3 := newTestFailoverErr(500, false, false)
fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
require.True(t, fs.ForceCacheBilling, "不应重置")
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 边界条件
// ---------------------------------------------------------------------------
func TestHandleFailoverError_EdgeCases(t *testing.T) {
t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(0, false, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
})
t.Run("AccountID为0也能正常跟踪", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, true, false)
action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[0])
})
t.Run("负AccountID也能正常跟踪", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(500, true, false)
action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[-1])
})
t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
fs.SwitchCount = 1
err := newTestFailoverErr(500, false, false)
start := time.Now()
action := fs.HandleFailoverError(context.Background(), mock, 100, "", err)
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟")
})
}
// ---------------------------------------------------------------------------
// HandleSelectionExhausted 测试
// ---------------------------------------------------------------------------
func TestHandleSelectionExhausted(t *testing.T) {
t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(3, false)
// LastFailoverErr 为 nil
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverExhausted, action)
})
t.Run("非503错误返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(500, false, false)
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverExhausted, action)
})
t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.FailedAccountIDs[100] = struct{}{}
fs.SwitchCount = 1
start := time.Now()
action := fs.HandleSelectionExhausted(context.Background())
elapsed := time.Since(start)
require.Equal(t, FailoverContinue, action)
require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表")
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s")
require.Less(t, elapsed, 5*time.Second)
})
t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) {
fs := NewFailoverState(2, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.SwitchCount = 3 // > MaxSwitches(2)
start := time.Now()
action := fs.HandleSelectionExhausted(context.Background())
elapsed := time.Since(start)
require.Equal(t, FailoverExhausted, action)
require.Less(t, elapsed, 100*time.Millisecond, "不应等待")
})
t.Run("503但context已取消_返回Canceled", func(t *testing.T) {
fs := NewFailoverState(3, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
action := fs.HandleSelectionExhausted(ctx)
elapsed := time.Since(start)
require.Equal(t, FailoverCanceled, action)
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
})
t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) {
fs := NewFailoverState(2, false)
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
fs.SwitchCount = 2 // == MaxSwitches条件是 <=,仍可重试
action := fs.HandleSelectionExhausted(context.Background())
require.Equal(t, FailoverContinue, action)
})
}

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
if platform == service.PlatformGemini {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
@@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
reqLog.Warn("gateway.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
if fs.LastFailoverErr != nil {
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
@@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
case FailoverExhausted:
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
return
case FailoverCanceled:
return
}
switchCount++
reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed",
@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: forceCacheBilling,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
reqLog.Warn("gateway.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
if fs.LastFailoverErr != nil {
h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
@@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
case FailoverExhausted:
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
return
case FailoverCanceled:
return
}
switchCount++
reqLog.Warn("gateway.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed",
@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: forceCacheBilling,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -733,11 +650,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
).Error("gateway.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("gateway.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
zap.Bool("fallback_used", fallbackUsed),
)
return
}
if !retryWithFallback {
@@ -982,69 +894,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
)
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
func sleepSameAccountRetryDelay(ctx context.Context) bool {
select {
case <-ctx.Done():
return false
case <-time.After(sameAccountRetryDelay):
return true
}
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
delay := time.Duration(switchCount-1) * time.Second
if delay <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
// 当分组内只有一个可用账号且上游返回 503MODEL_CAPACITY_EXHAUSTED时使用
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
// (最多 3 次、总等待 30s所以 Handler 层的退避只需短暂等待即可。
// 返回 false 表示 context 已取消。
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
// 固定短延时2s
// Service 层已经在原地等待了足够长的时间retryDelay × 重试次数),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
const delay = 2 * time.Second
logger.L().With(
zap.String("component", "handler.gateway.failover"),
zap.Duration("delay", delay),
zap.Int("retry_count", retryCount),
).Info("gateway.single_account_backoff_waiting")
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody

View File

@@ -1,51 +0,0 @@
package handler
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// sleepAntigravitySingleAccountBackoff 测试
// ---------------------------------------------------------------------------
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.True(t, ok, "should return true when context is not canceled")
// 固定延迟 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
}
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.False(t, ok, "should return false when context is canceled")
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
}
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
// 验证不同 retryCount 都使用固定 2s 延迟
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
elapsed := time.Since(start)
require.True(t, ok)
// 即使 retryCount=5延迟仍然是固定的 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
require.Less(t, elapsed, 5*time.Second)
}

View File

@@ -0,0 +1,340 @@
//go:build unit
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// 目标严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
type fakeSchedulerCache struct {
accounts []*service.Account
}
func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) {
return f.accounts, true, nil
}
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
return nil
}
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
return nil, nil
}
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil }
func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
return nil
}
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
return true, nil
}
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil }
func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil }
type fakeGroupRepo struct {
group *service.Group
}
func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil }
func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) {
return f.group, nil
}
func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
return f.group, nil
}
func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil }
func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil }
func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil }
func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil }
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
return nil, nil
}
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
return nil, nil
}
func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil }
func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
return nil
}
type fakeConcurrencyCache struct{}
func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil }
func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
return 0, nil
}
func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
return 0, nil
}
func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil }
func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil }
func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
return true, nil
}
func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
return map[int64]*service.AccountLoadInfo{}, nil
}
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
return map[int64]*service.UserLoadInfo{}, nil
}
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
t.Helper()
schedulerCache := &fakeSchedulerCache{accounts: accounts}
schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil)
gwSvc := service.NewGatewayService(
nil, // accountRepo (not used: scheduler snapshot hit)
&fakeGroupRepo{group: group},
nil, // usageLogRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache (disable sticky)
nil, // cfg
schedulerSnapshot,
nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
nil, // billingService
nil, // rateLimitService
nil, // billingCacheService
nil, // identityService
nil, // httpUpstream
nil, // deferredService
nil, // claudeTokenProvider
nil, // sessionLimitCache
nil, // digestStore
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
h := &GatewayHandler{
gatewayService: gwSvc,
billingCacheService: billingCacheSvc,
concurrencyHelper: concurrencyHelper,
// 这些字段对本测试不敏感,保持较小即可
maxAccountSwitches: 1,
maxAccountSwitchesGemini: 1,
}
cleanup := func() {
billingCacheSvc.Stop()
}
return h, cleanup
}
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(2001)
accountID := int64(1001)
group := &service.Group{
ID: groupID,
Hydrated: true,
Platform: service.PlatformAnthropic, // /v1/messagesClaude兼容入口
Status: service.StatusActive,
}
account := &service.Account{
ID: accountID,
Name: "ag-1",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "tok_xxx",
"intercept_warmup_requests": true,
},
Extra: map[string]any{
"mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中
},
Concurrency: 1,
Priority: 1,
Status: service.StatusActive,
Schedulable: true,
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
}
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
defer cleanup()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`)
req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group))
c.Request = req
apiKey := &service.APIKey{
ID: 3001,
UserID: 4001,
GroupID: &groupID,
Status: service.StatusActive,
User: &service.User{
ID: 4001,
Concurrency: 10,
Balance: 100,
},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
h.Messages(c)
require.Equal(t, 200, rec.Code)
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
selected, ok := c.Get(opsAccountIDKey)
require.True(t, ok)
require.Equal(t, accountID, selected)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "msg_mock_warmup", resp["id"])
require.Equal(t, "claude-sonnet-4-5", resp["model"])
content, ok := resp["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
first, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "New Conversation", first["text"])
}
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(2002)
accountID := int64(1002)
group := &service.Group{
ID: groupID,
Hydrated: true,
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
}
account := &service.Account{
ID: accountID,
Name: "ag-2",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "tok_xxx",
"intercept_warmup_requests": true,
},
Concurrency: 1,
Priority: 1,
Status: service.StatusActive,
Schedulable: true,
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
}
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
defer cleanup()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`)
req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
// - 写入 request.ContextService读取
// - 写入 gin.ContextHandler快速读取
ctx := context.WithValue(req.Context(), ctxkey.Group, group)
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity)
req = req.WithContext(ctx)
c.Request = req
c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)
apiKey := &service.APIKey{
ID: 3002,
UserID: 4002,
GroupID: &groupID,
Status: service.StatusActive,
User: &service.User{
ID: 4002,
Concurrency: 10,
Balance: 100,
},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
h.Messages(c)
require.Equal(t, 200, rec.Code)
selected, ok := c.Get(opsAccountIDKey)
require.True(t, ok)
require.Equal(t, accountID, selected)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "msg_mock_warmup", resp["id"])
require.Equal(t, "claude-sonnet-4-5", resp["model"])
}

View File

@@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
@@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
reqLog.Warn("gemini.single_account_retrying",
zap.Int("retry_count", switchCount),
zap.Int("max_retries", maxAccountSwitches),
)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
return
default: // FailoverExhausted
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
return
}
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
@@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
lastFailoverErr = failoverErr
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch failoverAction {
case FailoverContinue:
continue
case FailoverExhausted:
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
return
case FailoverCanceled:
return
}
lastFailoverErr = failoverErr
switchCount++
reqLog.Warn("gemini.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// ForwardNative already wrote the response
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
@@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: forceCacheBilling,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
})
reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
zap.Int("switch_count", fs.SwitchCount),
)
return
}

View File

@@ -400,7 +400,9 @@ func TestShouldFallbackToNextURL_无错误且200(t *testing.T) {
// ---------------------------------------------------------------------------
func TestClient_ExchangeCode_成功(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求方法
@@ -493,7 +495,9 @@ func TestClient_ExchangeCode_成功(t *testing.T) {
}
func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := NewClient("")
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
@@ -506,7 +510,9 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
}
func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
@@ -531,7 +537,9 @@ func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
// ---------------------------------------------------------------------------
func TestClient_RefreshToken_MockServer(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -590,7 +598,9 @@ func TestClient_RefreshToken_MockServer(t *testing.T) {
}
func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := NewClient("")
_, err := client.RefreshToken(context.Background(), "refresh-tok")
@@ -784,7 +794,9 @@ func newTestClientWithRedirect(redirects map[string]string) *Client {
// ---------------------------------------------------------------------------
func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -853,7 +865,9 @@ func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
}
func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
@@ -878,7 +892,9 @@ func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
}
func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@@ -901,7 +917,9 @@ func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
}
func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second) // 模拟慢响应
@@ -927,7 +945,9 @@ func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
// ---------------------------------------------------------------------------
func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -976,7 +996,9 @@ func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
}
func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
@@ -998,7 +1020,9 @@ func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
}
func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@@ -1021,7 +1045,9 @@ func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
}
func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
old := defaultClientSecret
defaultClientSecret = "test-secret"
t.Cleanup(func() { defaultClientSecret = old })
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)

View File

@@ -23,11 +23,9 @@ const (
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
// Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = ""
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
// 出于安全原因,该值不得硬编码入库。
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
// 固定的 redirect_uri用户需手动复制 code
@@ -51,14 +49,21 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.84.2
var defaultUserAgentVersion = "1.84.2"
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.18.4
var defaultUserAgentVersion = "1.18.4"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"
func init() {
// 从环境变量读取版本号,未设置则使用默认值
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
defaultUserAgentVersion = version
}
// 从环境变量读取 client_secret未设置则使用默认值
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
defaultClientSecret = secret
}
}
// GetUserAgent 返回当前配置的 User-Agent
@@ -67,14 +72,9 @@ func GetUserAgent() string {
}
func getClientSecret() (string, error) {
if v := strings.TrimSpace(ClientSecret); v != "" {
if v := strings.TrimSpace(defaultClientSecret); v != "" {
return v, nil
}
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
if vv := strings.TrimSpace(v); vv != "" {
return vv, nil
}
}
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/hex"
"net/url"
"os"
"strings"
"testing"
"time"
@@ -17,8 +18,14 @@ import (
// ---------------------------------------------------------------------------
func TestGetClientSecret_环境变量设置(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
// 需要重新触发 init 逻辑:手动从环境变量读取
defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv)
secret, err := getClientSecret()
if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err)
@@ -29,11 +36,13 @@ func TestGetClientSecret_环境变量设置(t *testing.T) {
}
func TestGetClientSecret_环境变量为空(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量为空时应返回错误")
t.Fatal("defaultClientSecret 为空时应返回错误")
}
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
@@ -41,30 +50,31 @@ func TestGetClientSecret_环境变量为空(t *testing.T) {
}
func TestGetClientSecret_环境变量未设置(t *testing.T) {
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
// 明确设置再取消,确保环境变量不存在
t.Setenv(AntigravityOAuthClientSecretEnv, "")
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量未设置时应返回错误")
t.Fatal("defaultClientSecret 为空时应返回错误")
}
}
func TestGetClientSecret_环境变量含空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
old := defaultClientSecret
defaultClientSecret = " "
t.Cleanup(func() { defaultClientSecret = old })
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量仅含空格时应返回错误")
t.Fatal("defaultClientSecret 仅含空格时应返回错误")
}
}
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
old := defaultClientSecret
defaultClientSecret = " valid-secret "
t.Cleanup(func() { defaultClientSecret = old })
secret, err := getClientSecret()
if err != nil {
@@ -670,13 +680,17 @@ func TestConstants_值正确(t *testing.T) {
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
t.Errorf("ClientID 不匹配: got %s", ClientID)
}
if ClientSecret != "" {
t.Error("ClientSecret 应为空字符串")
secret, err := getClientSecret()
if err != nil {
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
}
if secret != "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" {
t.Errorf("默认 client_secret 不匹配: got %s", secret)
}
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
if GetUserAgent() != "antigravity/1.84.2 windows/amd64" {
if GetUserAgent() != "antigravity/1.18.4 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
}
if SessionTTL != 30*time.Minute {

View File

@@ -206,6 +206,7 @@ type modelInfo struct {
var modelInfoMap = map[string]modelInfo{
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
"claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"},
"claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"},
}

View File

@@ -38,10 +38,8 @@ const (
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
// GeminiCLIOAuthClientSecret is intentionally not embedded in this repository.
// If you rely on the built-in Gemini CLI OAuth client, you MUST provide its client_secret via config/env.
GeminiCLIOAuthClientSecret = ""
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"

View File

@@ -408,11 +408,10 @@ func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
}
}
func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
// 不设置环境变量,也不提供 client 凭据EffectiveOAuthConfig 应该报错
func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
_, err := BuildAuthorizationURL(
authURL, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
@@ -420,8 +419,11 @@ func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
"",
"code_assist",
)
if err == nil {
t.Error("当 EffectiveOAuthConfig 失败时BuildAuthorizationURL 应该返回错误")
if err != nil {
t.Fatalf("BuildAuthorizationURL() 不应报错: %v", err)
}
if !strings.Contains(authURL, "client_id="+GeminiCLIOAuthClientID) {
t.Errorf("应使用内置 Gemini CLI client_id实际 URL: %s", authURL)
}
}
@@ -685,15 +687,17 @@ func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
}
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
// 不设置环境变量且不提供凭据,应该报错
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
if err == nil {
t.Error("没有内置 secret 且未提供凭据时应该报错")
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
if err != nil {
t.Fatalf("不设置环境变量时应回退到内置 secret实际报错: %v", err)
}
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
t.Errorf("错误消息应提及环境变量 %s实际: %v", GeminiCLIOAuthClientSecretEnv, err)
if strings.TrimSpace(cfg.ClientSecret) == "" {
t.Error("ClientSecret 不应为空")
}
if cfg.ClientID != GeminiCLIOAuthClientID {
t.Errorf("ClientID 应回退为内置客户端 ID实际: %q", cfg.ClientID)
}
}

View File

@@ -219,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)

View File

@@ -372,6 +372,13 @@ func (a *Account) GetModelMapping() map[string]string {
}
}
if len(result) > 0 {
if a.Platform == domain.PlatformAntigravity {
ensureAntigravityDefaultPassthroughs(result, []string{
"gemini-3-flash",
"gemini-3.1-pro-high",
"gemini-3.1-pro-low",
})
}
return result
}
}
@@ -382,6 +389,27 @@ func (a *Account) GetModelMapping() map[string]string {
return nil
}
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
if mapping == nil || model == "" {
return
}
if _, exists := mapping[model]; exists {
return
}
for pattern := range mapping {
if matchWildcard(pattern, model) {
return
}
}
mapping[model] = model
}
func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []string) {
for _, model := range models {
ensureAntigravityDefaultPassthrough(mapping, model)
}
}
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
// 如果未配置 mapping返回 true允许所有模型
func (a *Account) IsModelSupported(requestedModel string) bool {

View File

@@ -0,0 +1,66 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_IsInterceptWarmupEnabled(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
expected bool
}{
{
name: "nil credentials",
credentials: nil,
expected: false,
},
{
name: "empty map",
credentials: map[string]any{},
expected: false,
},
{
name: "field not present",
credentials: map[string]any{"access_token": "tok"},
expected: false,
},
{
name: "field is true",
credentials: map[string]any{"intercept_warmup_requests": true},
expected: true,
},
{
name: "field is false",
credentials: map[string]any{"intercept_warmup_requests": false},
expected: false,
},
{
name: "field is string true",
credentials: map[string]any{"intercept_warmup_requests": "true"},
expected: false,
},
{
name: "field is int 1",
credentials: map[string]any{"intercept_warmup_requests": 1},
expected: false,
},
{
name: "field is nil",
credentials: map[string]any{"intercept_warmup_requests": nil},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Credentials: tt.credentials}
result := a.IsInterceptWarmupEnabled()
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"strings"
"sync"
"time"
@@ -217,12 +218,20 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
}
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
usage, err := s.getGeminiUsage(ctx, account)
if err == nil {
s.tryClearRecoverableAccountError(ctx, account)
}
return usage, err
}
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
if account.Platform == PlatformAntigravity {
return s.getAntigravityUsage(ctx, account)
usage, err := s.getAntigravityUsage(ctx, account)
if err == nil {
s.tryClearRecoverableAccountError(ctx, account)
}
return usage, err
}
// 只有oauth类型账号可以通过API获取usage有profile scope
@@ -256,6 +265,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
// 4. 添加窗口统计有独立缓存1 分钟)
s.addWindowStats(ctx, account, usage)
s.tryClearRecoverableAccountError(ctx, account)
return usage, nil
}
@@ -486,6 +496,32 @@ func parseTime(s string) (time.Time, error) {
return time.Time{}, fmt.Errorf("unable to parse time: %s", s)
}
func (s *AccountUsageService) tryClearRecoverableAccountError(ctx context.Context, account *Account) {
if account == nil || account.Status != StatusError {
return
}
msg := strings.ToLower(strings.TrimSpace(account.ErrorMessage))
if msg == "" {
return
}
if !strings.Contains(msg, "token refresh failed") &&
!strings.Contains(msg, "invalid_client") &&
!strings.Contains(msg, "missing_project_id") &&
!strings.Contains(msg, "unauthenticated") {
return
}
if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
log.Printf("[usage] failed to clear recoverable account error for account %d: %v", account.ID, err)
return
}
account.Status = StatusActive
account.ErrorMessage = ""
}
// buildUsageInfo 构建UsageInfo
func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo {
info := &UsageInfo{

View File

@@ -267,3 +267,50 @@ func TestAccountGetMappedModel(t *testing.T) {
})
}
}
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3-pro-high": "gemini-3.1-pro-high",
},
},
}
mapping := account.GetModelMapping()
if mapping["gemini-3-flash"] != "gemini-3-flash" {
t.Fatalf("expected gemini-3-flash passthrough to be auto-filled, got: %q", mapping["gemini-3-flash"])
}
if mapping["gemini-3.1-pro-high"] != "gemini-3.1-pro-high" {
t.Fatalf("expected gemini-3.1-pro-high passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-high"])
}
if mapping["gemini-3.1-pro-low"] != "gemini-3.1-pro-low" {
t.Fatalf("expected gemini-3.1-pro-low passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-low"])
}
}
func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3*": "gemini-3.1-pro-high",
},
},
}
mapping := account.GetModelMapping()
if _, exists := mapping["gemini-3-flash"]; exists {
t.Fatalf("did not expect explicit gemini-3-flash passthrough when wildcard already exists")
}
if _, exists := mapping["gemini-3.1-pro-high"]; exists {
t.Fatalf("did not expect explicit gemini-3.1-pro-high passthrough when wildcard already exists")
}
if _, exists := mapping["gemini-3.1-pro-low"]; exists {
t.Fatalf("did not expect explicit gemini-3.1-pro-low passthrough when wildcard already exists")
}
if mapped := account.GetMappedModel("gemini-3-flash"); mapped != "gemini-3.1-pro-high" {
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
}
}

View File

@@ -54,6 +54,7 @@ type AdminService interface {
SetAccountError(ctx context.Context, id int64, errorMsg string) error
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
@@ -2114,6 +2115,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
}
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
if s.proxyLatencyCache == nil || len(proxies) == 0 {
return

View File

@@ -87,7 +87,6 @@ var (
)
const (
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
@@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
billingModel := mappedModel
// 获取 access_token
if s.tokenProvider == nil {
@@ -1370,6 +1370,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
ForceCacheBilling: switchErr.IsStickySession,
}
}
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if c.Request.Context().Err() != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
resp := result.resp
@@ -1618,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Model: billingModel, // 使用映射模型用于计费和日志
Stream: claudeReq.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
@@ -1972,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if mappedModel == "" {
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
}
billingModel := mappedModel
// 获取 access_token
if s.tokenProvider == nil {
@@ -2042,6 +2047,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
ForceCacheBilling: switchErr.IsStickySession,
}
}
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if c.Request.Context().Err() != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response")
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
resp := result.resp
@@ -2197,7 +2206,7 @@ handleSuccess:
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Model: billingModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
@@ -2642,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
defaultDur := s.getDefaultRateLimitDuration()
// 尝试解析模型 key 并设置模型级限流
modelKey := resolveAntigravityModelKey(requestedModel)
//
// 注意requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
// 因此这里必须写入最终模型 key确保后续调度能正确避开已限流模型。
modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel)
if strings.TrimSpace(modelKey) == "" {
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
modelKey = resolveAntigravityModelKey(requestedModel)
}
if modelKey != "" {
ra := s.resolveResetTime(resetAt, defaultDur)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
@@ -3881,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
@@ -3934,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
Model: originalModel,
}, nil
}
@@ -3975,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Model: originalModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,

View File

@@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
return s.resp, s.err
}
type antigravitySettingRepoStub struct{}
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
return "", ErrSettingNotFound
}
func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
@@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
account := &Account{
@@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
// 验证Antigravity Claude 转发返回的计费模型使用映射后的模型
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{
{"role": "user", "content": "hello"},
},
"max_tokens": 16,
"stream": true,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-bill-1"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
const mappedModel = "gemini-3-pro-high"
account := &Account{
ID: 5,
Name: "acc-forward-billing",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
"claude-sonnet-4-5": mappedModel,
},
},
}
result, err := svc.Forward(context.Background(), c, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
}
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
// 验证Antigravity Gemini 转发返回的计费模型使用映射后的模型
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-bill-2"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
const mappedModel = "gemini-3-pro-high"
account := &Account{
ID: 6,
Name: "acc-gemini-billing",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
"gemini-2.5-flash": mappedModel,
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
}
// TestStreamUpstreamResponse_UsageAndFirstToken
// 验证usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {

View File

@@ -76,6 +76,12 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
},
// 3. 默认映射中的透传(映射到自己)
{
name: "默认映射透传 - claude-sonnet-4-6",
requestedModel: "claude-sonnet-4-6",
accountMapping: nil,
expected: "claude-sonnet-4-6",
},
{
name: "默认映射透传 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",

View File

@@ -197,6 +197,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景
// 验证requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking
func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("5s")
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false)
require.Nil(t, result)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {

View File

@@ -133,6 +133,18 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
SupportsCacheBreakdown: false,
}
// Claude 4.6 Opus (与4.5同价)
s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"]
// Gemini 3.1 Pro
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
InputPricePerToken: 2e-6, // $2 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
CacheCreationPricePerToken: 2e-6, // $2 per MTok
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
SupportsCacheBreakdown: false,
}
}
// getFallbackPricing 根据模型系列获取回退价格
@@ -141,6 +153,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
// 按模型系列匹配
if strings.Contains(modelLower, "opus") {
if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") {
return s.fallbackPrices["claude-opus-4.6"]
}
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
return s.fallbackPrices["claude-opus-4.5"]
}
@@ -158,6 +173,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
}
return s.fallbackPrices["claude-3-haiku"]
}
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
return s.fallbackPrices["gemini-3.1-pro"]
}
// 默认使用Sonnet价格
return s.fallbackPrices["claude-sonnet-4"]

View File

@@ -895,6 +895,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t
require.Equal(t, int64(2), acc.ID)
}
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}},
},
{
ID: 2,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Priority: 2,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}},
},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号")
acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "supporting model")
}
func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(50)
@@ -1070,6 +1119,36 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Gemini平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
},
},
model: "gemini-2.5-flash",
expected: false,
},
{
name: "Gemini平台-有映射配置-支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
},
},
model: "gemini-2.5-pro",
expected: true,
},
}
for _, tt := range tests {

View File

@@ -2825,10 +2825,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
return true
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
}

View File

@@ -107,12 +107,12 @@ func TestIsModelRateLimited(t *testing.T) {
expected: true,
},
{
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3.1-pro-high",
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3-pro-high",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-3.1-pro-high": map[string]any{
"gemini-3-pro-high": map[string]any{
"rate_limit_reset_at": future,
},
},

View File

@@ -0,0 +1,42 @@
-- Add claude-sonnet-4-6 to model_mapping for all Antigravity accounts
--
-- Background:
-- Antigravity now supports claude-sonnet-4-6
--
-- Strategy:
-- Directly overwrite the entire model_mapping with updated mappings
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{model_mapping}',
'{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-6": "claude-opus-4-6-thinking",
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}'::jsonb
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping' IS NOT NULL;

View File

@@ -0,0 +1,45 @@
-- Add gemini-3.1-pro-high, gemini-3.1-pro-low, gemini-3.1-pro-preview to model_mapping
--
-- Background:
-- Antigravity now supports gemini-3.1-pro-high and gemini-3.1-pro-low
--
-- Strategy:
-- Directly overwrite the entire model_mapping with updated mappings
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{model_mapping}',
'{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-6": "claude-opus-4-6-thinking",
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}'::jsonb
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping' IS NOT NULL;

View File

@@ -15,7 +15,9 @@ import type {
AccountUsageStatsResponse,
TempUnschedulableStatus,
AdminDataPayload,
AdminDataImportResult
AdminDataImportResult,
CheckMixedChannelRequest,
CheckMixedChannelResponse
} from '@/types'
/**
@@ -133,6 +135,16 @@ export async function update(id: number, updates: UpdateAccountRequest): Promise
return data
}
/**
* Check mixed-channel risk for account-group binding.
*/
export async function checkMixedChannelRisk(
payload: CheckMixedChannelRequest
): Promise<CheckMixedChannelResponse> {
const { data } = await apiClient.post<CheckMixedChannelResponse>('/admin/accounts/check-mixed-channel', payload)
return data
}
/**
* Delete account
* @param id - Account ID
@@ -535,6 +547,7 @@ export const accountsAPI = {
getById,
create,
update,
checkMixedChannelRisk,
delete: deleteAccount,
toggleStatus,
testAccount,

View File

@@ -77,13 +77,23 @@
</div>
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
<template v-if="activeModelRateLimits.length > 0">
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
<div
v-if="activeModelRateLimits.length > 0"
:class="[
activeModelRateLimits.length <= 4
? 'flex flex-col gap-1'
: activeModelRateLimits.length <= 8
? 'columns-2 gap-x-2'
: 'columns-3 gap-x-2'
]"
>
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative mb-1 break-inside-avoid">
<span
class="inline-flex items-center gap-1 rounded bg-purple-100 px-1.5 py-0.5 text-xs font-medium text-purple-700 dark:bg-purple-900/30 dark:text-purple-400"
>
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
{{ formatScopeName(item.model) }}
<span class="text-[10px] opacity-70">{{ formatModelResetTime(item.reset_at) }}</span>
</span>
<!-- Tooltip -->
<div
@@ -95,7 +105,7 @@
></div>
</div>
</div>
</template>
</div>
<!-- Overload Indicator (529) -->
<div v-if="isOverloaded" class="group relative">
@@ -154,17 +164,50 @@ const activeModelRateLimits = computed(() => {
})
const formatScopeName = (scope: string): string => {
const names: Record<string, string> = {
const aliases: Record<string, string> = {
// Claude 系列
'claude-opus-4-6-thinking': 'COpus46',
'claude-sonnet-4-6': 'CSon46',
'claude-sonnet-4-5': 'CSon45',
'claude-sonnet-4-5-thinking': 'CSon45T',
// Gemini 2.5 系列
'gemini-2.5-flash': 'G25F',
'gemini-2.5-flash-lite': 'G25FL',
'gemini-2.5-flash-thinking': 'G25FT',
'gemini-2.5-pro': 'G25P',
// Gemini 3 系列
'gemini-3-flash': 'G3F',
'gemini-3.1-pro-high': 'G3PH',
'gemini-3.1-pro-low': 'G3PL',
'gemini-3-pro-image': 'G3PI',
// 其他
'gpt-oss-120b-medium': 'GPT120',
'tab_flash_lite_preview': 'TabFL',
// 旧版 scope 别名(兼容)
claude: 'Claude',
claude_sonnet: 'Claude Sonnet',
claude_opus: 'Claude Opus',
claude_haiku: 'Claude Haiku',
claude_sonnet: 'CSon',
claude_opus: 'COpus',
claude_haiku: 'CHaiku',
gemini_text: 'Gemini',
gemini_image: 'Image',
gemini_flash: 'Gemini Flash',
gemini_pro: 'Gemini Pro'
gemini_image: 'GImg',
gemini_flash: 'GFlash',
gemini_pro: 'GPro',
}
return names[scope] || scope
return aliases[scope] || scope
}
const formatModelResetTime = (resetAt: string): string => {
const date = new Date(resetAt)
const now = new Date()
const diffMs = date.getTime() - now.getTime()
if (diffMs <= 0) return ''
const totalSecs = Math.floor(diffMs / 1000)
const h = Math.floor(totalSecs / 3600)
const m = Math.floor((totalSecs % 3600) / 60)
const s = totalSecs % 60
if (h > 0) return `${h}h${m}m`
if (m > 0) return `${m}m${s}s`
return `${s}s`
}
// Computed: is overloaded (529)

View File

@@ -172,12 +172,12 @@
color="purple"
/>
<!-- Claude 4.5 -->
<!-- Claude -->
<UsageProgressBar
v-if="antigravityClaude45UsageFromAPI !== null"
:label="t('admin.accounts.usageWindow.claude45')"
:utilization="antigravityClaude45UsageFromAPI.utilization"
:resets-at="antigravityClaude45UsageFromAPI.resetTime"
v-if="antigravityClaudeUsageFromAPI !== null"
:label="t('admin.accounts.usageWindow.claude')"
:utilization="antigravityClaudeUsageFromAPI.utilization"
:resets-at="antigravityClaudeUsageFromAPI.resetTime"
color="amber"
/>
</div>
@@ -400,9 +400,12 @@ const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(
// Gemini 3 Image from API
const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-pro-image']))
// Claude 4.5 from API
const antigravityClaude45UsageFromAPI = computed(() =>
getAntigravityUsageFromAPI(['claude-sonnet-4-5', 'claude-opus-4-5-thinking'])
// Claude from API (all Claude model variants)
const antigravityClaudeUsageFromAPI = computed(() =>
getAntigravityUsageFromAPI([
'claude-sonnet-4-5', 'claude-opus-4-5-thinking',
'claude-sonnet-4-6', 'claude-opus-4-6-thinking',
])
)
// Antigravity 账户类型(从 load_code_assist 响应中提取)

View File

@@ -209,7 +209,7 @@
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
<div
v-for="(mapping, index) in modelMappings"
:key="getModelMappingKey(mapping)"
:key="index"
class="flex items-center gap-2"
>
<input
@@ -654,7 +654,7 @@ import Select from '@/components/common/Select.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import Icon from '@/components/icons/Icon.vue'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import { buildModelMappingObject as buildModelMappingPayload } from '@/composables/useModelWhitelist'
interface Props {
show: boolean
@@ -696,7 +696,6 @@ const baseUrl = ref('')
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([])
const modelMappings = ref<ModelMapping[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('bulk-model-mapping')
const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
@@ -707,7 +706,7 @@ const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([])
// All models list (combined Anthropic + OpenAI)
// All models list (combined Anthropic + OpenAI + Gemini)
const allModels = [
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
@@ -719,17 +718,21 @@ const allModels = [
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
{ value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' },
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' }
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
{ value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' },
{ value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' },
{ value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' },
{ value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' },
{ value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' }
]
// Preset mappings (combined Anthropic + OpenAI)
// Preset mappings (combined Anthropic + OpenAI + Gemini)
const presetMappings = [
{
label: 'Sonnet 4',
@@ -765,18 +768,37 @@ const presetMappings = [
color:
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
},
{
label: 'Sonnet4→4.6',
from: 'claude-sonnet-4-20250514',
to: 'claude-sonnet-4-6',
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
},
{
label: 'Sonnet4.5→4.6',
from: 'claude-sonnet-4-5-20250929',
to: 'claude-sonnet-4-6',
color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400'
},
{
label: 'Sonnet3.5→4.6',
from: 'claude-3-5-sonnet-20241022',
to: 'claude-sonnet-4-6',
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
},
{
label: 'Opus4.5→4.6',
from: 'claude-opus-4-5-20251101',
to: 'claude-opus-4-6-thinking',
color:
'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400'
},
{
label: 'Opus->Sonnet',
from: 'claude-opus-4-5-20251101',
to: 'claude-sonnet-4-5-20250929',
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
},
{
label: 'GPT-5.3 Codex Spark',
from: 'gpt-5.3-codex-spark',
to: 'gpt-5.3-codex-spark',
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
},
{
label: 'GPT-5.2',
from: 'gpt-5.2-2025-12-11',
@@ -794,6 +816,36 @@ const presetMappings = [
from: 'gpt-5.1-codex-max',
to: 'gpt-5.1-codex',
color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400'
},
{
label: '3-Pro-Preview→3.1-Pro-High',
from: 'gemini-3-pro-preview',
to: 'gemini-3.1-pro-high',
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
},
{
label: '3-Pro-High→3.1-Pro-High',
from: 'gemini-3-pro-high',
to: 'gemini-3.1-pro-high',
color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400'
},
{
label: '3-Pro-Low→3.1-Pro-Low',
from: 'gemini-3-pro-low',
to: 'gemini-3.1-pro-low',
color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400'
},
{
label: '3-Flash透传',
from: 'gemini-3-flash',
to: 'gemini-3-flash',
color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400'
},
{
label: '2.5-Flash-Lite透传',
from: 'gemini-2.5-flash-lite',
to: 'gemini-2.5-flash-lite',
color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400'
}
]
@@ -883,23 +935,11 @@ const removeErrorCode = (code: number) => {
}
const buildModelMappingObject = (): Record<string, string> | null => {
const mapping: Record<string, string> = {}
if (modelRestrictionMode.value === 'whitelist') {
for (const model of allowedModels.value) {
mapping[model] = model
}
} else {
for (const m of modelMappings.value) {
const from = m.from.trim()
const to = m.to.trim()
if (from && to) {
mapping[from] = to
}
}
}
return Object.keys(mapping).length > 0 ? mapping : null
return buildModelMappingPayload(
modelRestrictionMode.value,
allowedModels.value,
modelMappings.value
)
}
const buildUpdatePayload = (): Record<string, unknown> | null => {

View File

@@ -916,8 +916,8 @@
<p class="input-hint">{{ t('admin.accounts.gemini.tier.aiStudioHint') }}</p>
</div>
<!-- Model Restriction Section (不适用于 GeminiAntigravity 已在上层条件排除) -->
<div v-if="form.platform !== 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<!-- Model Restriction Section (Antigravity 已在上层条件排除) -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<div
@@ -1200,34 +1200,6 @@
</div>
</div>
<!-- Gemini 模型说明 -->
<div v-if="form.platform === 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="rounded-lg bg-blue-50 p-4 dark:bg-blue-900/20">
<div class="flex items-start gap-3">
<svg
class="h-5 w-5 flex-shrink-0 text-blue-600 dark:text-blue-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
<div>
<p class="text-sm font-medium text-blue-800 dark:text-blue-300">
{{ t('admin.accounts.gemini.modelPassthrough') }}
</p>
<p class="mt-1 text-xs text-blue-700 dark:text-blue-400">
{{ t('admin.accounts.gemini.modelPassthroughDesc') }}
</p>
</div>
</div>
</div>
</div>
</div>
<!-- Temp Unschedulable Rules -->
@@ -1378,9 +1350,9 @@
</div>
</div>
<!-- Intercept Warmup Requests (Anthropic only) -->
<!-- Intercept Warmup Requests (Anthropic/Antigravity) -->
<div
v-if="form.platform === 'anthropic'"
v-if="form.platform === 'anthropic' || form.platform === 'antigravity'"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
>
<div class="flex items-center justify-between">
@@ -2157,7 +2129,7 @@
<ConfirmDialog
:show="showMixedChannelWarning"
:title="t('admin.accounts.mixedChannelWarningTitle')"
:message="mixedChannelWarningDetails ? t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails) : ''"
:message="mixedChannelWarningMessageText"
:confirm-text="t('common.confirm')"
:cancel-text="t('common.cancel')"
:danger="true"
@@ -2189,13 +2161,21 @@ import {
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
import type { Proxy, AdminGroup, AccountPlatform, AccountType } from '@/types'
import type {
Proxy,
AdminGroup,
AccountPlatform,
AccountType,
CheckMixedChannelResponse,
CreateAccountRequest
} from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Icon from '@/components/icons/Icon.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
@@ -2337,10 +2317,13 @@ const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
const geminiAIStudioOAuthEnabled = ref(false)
// Mixed channel warning dialog state
const showMixedChannelWarning = ref(false)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null)
const pendingCreatePayload = ref<any>(null)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
null
)
const mixedChannelWarningRawMessage = ref('')
const mixedChannelWarningAction = ref<(() => Promise<void>) | null>(null)
const antigravityMixedChannelConfirmed = ref(false)
const showAdvancedOAuth = ref(false)
const showGeminiHelpDialog = ref(false)
@@ -2378,6 +2361,13 @@ const isOpenAIModelRestrictionDisabled = computed(() =>
form.platform === 'openai' && openaiPassthroughEnabled.value
)
const mixedChannelWarningMessageText = computed(() => {
if (mixedChannelWarningDetails.value) {
return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value)
}
return mixedChannelWarningRawMessage.value
})
const geminiQuotaDocs = {
codeAssist: 'https://developers.google.com/gemini-code-assist/resources/quotas',
aiStudio: 'https://ai.google.dev/pricing',
@@ -2544,8 +2534,8 @@ watch(
antigravityModelMappings.value = []
antigravityModelRestrictionMode.value = 'mapping'
}
// Reset Anthropic-specific settings when switching to other platforms
if (newPlatform !== 'anthropic') {
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
interceptWarmupRequests.value = false
}
if (newPlatform === 'sora') {
@@ -2794,6 +2784,105 @@ const splitTempUnschedKeywords = (value: string) => {
.filter((item) => item.length > 0)
}
const needsMixedChannelCheck = (platform: AccountPlatform) => platform === 'antigravity' || platform === 'anthropic'
const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => {
const details = resp?.details
if (!details) {
return null
}
return {
groupName: details.group_name || 'Unknown',
currentPlatform: details.current_platform || 'Unknown',
otherPlatform: details.other_platform || 'Unknown'
}
}
const clearMixedChannelDialog = () => {
showMixedChannelWarning.value = false
mixedChannelWarningDetails.value = null
mixedChannelWarningRawMessage.value = ''
mixedChannelWarningAction.value = null
}
const openMixedChannelDialog = (opts: {
response?: CheckMixedChannelResponse
message?: string
onConfirm: () => Promise<void>
}) => {
mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response)
mixedChannelWarningRawMessage.value =
opts.message || opts.response?.message || t('admin.accounts.failedToCreate')
mixedChannelWarningAction.value = opts.onConfirm
showMixedChannelWarning.value = true
}
const withAntigravityConfirmFlag = (payload: CreateAccountRequest): CreateAccountRequest => {
if (needsMixedChannelCheck(payload.platform) && antigravityMixedChannelConfirmed.value) {
return {
...payload,
confirm_mixed_channel_risk: true
}
}
const cloned = { ...payload }
delete cloned.confirm_mixed_channel_risk
return cloned
}
const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise<void>): Promise<boolean> => {
if (!needsMixedChannelCheck(form.platform)) {
return true
}
if (antigravityMixedChannelConfirmed.value) {
return true
}
try {
const result = await adminAPI.accounts.checkMixedChannelRisk({
platform: form.platform,
group_ids: form.group_ids
})
if (!result.has_risk) {
return true
}
openMixedChannelDialog({
response: result,
onConfirm: async () => {
antigravityMixedChannelConfirmed.value = true
await onConfirm()
}
})
return false
} catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate'))
return false
}
}
const submitCreateAccount = async (payload: CreateAccountRequest) => {
submitting.value = true
try {
await adminAPI.accounts.create(withAntigravityConfirmFlag(payload))
appStore.showSuccess(t('admin.accounts.accountCreated'))
emit('created')
handleClose()
} catch (error: any) {
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck(form.platform)) {
openMixedChannelDialog({
message: error.response?.data?.message,
onConfirm: async () => {
antigravityMixedChannelConfirmed.value = true
await submitCreateAccount(payload)
}
})
return
}
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate'))
} finally {
submitting.value = false
}
}
// Methods
const resetForm = () => {
step.value = 1
@@ -2855,9 +2944,13 @@ const resetForm = () => {
geminiOAuth.resetState()
antigravityOAuth.resetState()
oauthFlowRef.value?.reset()
antigravityMixedChannelConfirmed.value = false
clearMixedChannelDialog()
}
const handleClose = () => {
antigravityMixedChannelConfirmed.value = false
clearMixedChannelDialog()
emit('close')
}
@@ -2916,56 +3009,34 @@ const buildSoraExtra = (
}
// Helper function to create account with mixed channel warning handling
const doCreateAccount = async (payload: any) => {
const doCreateAccount = async (payload: CreateAccountRequest) => {
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
await submitCreateAccount(payload)
})
if (!canContinue) {
return
}
await submitCreateAccount(payload)
}
// Handle mixed channel warning confirmation
const handleMixedChannelConfirm = async () => {
const action = mixedChannelWarningAction.value
if (!action) {
clearMixedChannelDialog()
return
}
clearMixedChannelDialog()
submitting.value = true
try {
await adminAPI.accounts.create(payload)
appStore.showSuccess(t('admin.accounts.accountCreated'))
emit('created')
handleClose()
} catch (error: any) {
// Handle 409 mixed_channel_warning - show confirmation dialog
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') {
const details = error.response.data.details || {}
mixedChannelWarningDetails.value = {
groupName: details.group_name || 'Unknown',
currentPlatform: details.current_platform || 'Unknown',
otherPlatform: details.other_platform || 'Unknown'
}
pendingCreatePayload.value = payload
showMixedChannelWarning.value = true
} else {
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
}
await action()
} finally {
submitting.value = false
}
}
// Handle mixed channel warning confirmation
const handleMixedChannelConfirm = async () => {
showMixedChannelWarning.value = false
if (pendingCreatePayload.value) {
pendingCreatePayload.value.confirm_mixed_channel_risk = true
submitting.value = true
try {
await adminAPI.accounts.create(pendingCreatePayload.value)
appStore.showSuccess(t('admin.accounts.accountCreated'))
emit('created')
handleClose()
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
} finally {
submitting.value = false
pendingCreatePayload.value = null
}
}
}
const handleMixedChannelCancel = () => {
showMixedChannelWarning.value = false
pendingCreatePayload.value = null
mixedChannelWarningDetails.value = null
clearMixedChannelDialog()
}
const handleSubmit = async () => {
@@ -2975,6 +3046,12 @@ const handleSubmit = async () => {
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
return
}
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
step.value = 2
})
if (!canContinue) {
return
}
step.value = 2
return
}
@@ -3010,15 +3087,10 @@ const handleSubmit = async () => {
credentials.model_mapping = antigravityModelMapping
}
submitting.value = true
try {
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
} finally {
submitting.value = false
}
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
return
}
@@ -3059,10 +3131,7 @@ const handleSubmit = async () => {
credentials.custom_error_codes = [...selectedErrorCodes.value]
}
// Add intercept warmup requests setting
if (interceptWarmupRequests.value) {
credentials.intercept_warmup_requests = true
}
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
if (!applyTempUnschedConfig(credentials)) {
return
}
@@ -3132,7 +3201,7 @@ const createAccountAndFinish = async (
if (!applyTempUnschedConfig(credentials)) {
return
}
await adminAPI.accounts.create({
await doCreateAccount({
name: form.name,
notes: form.notes,
platform,
@@ -3147,9 +3216,6 @@ const createAccountAndFinish = async (
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
})
appStore.showSuccess(t('admin.accounts.accountCreated'))
emit('created')
handleClose()
}
// OpenAI OAuth 授权码兑换
@@ -3497,7 +3563,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
// Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials
await adminAPI.accounts.create({
const createPayload = withAntigravityConfirmFlag({
name: accountName,
notes: form.notes,
platform: 'antigravity',
@@ -3512,6 +3578,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
})
await adminAPI.accounts.create(createPayload)
successCount++
} catch (error: any) {
failedCount++
@@ -3606,6 +3673,7 @@ const handleAntigravityExchange = async (authCode: string) => {
if (!tokenInfo) return
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
// Antigravity 只使用映射模式
const antigravityModelMapping = buildModelMappingObject(
'mapping',
@@ -3677,10 +3745,8 @@ const handleAnthropicExchange = async (authCode: string) => {
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
}
const credentials = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
}
const credentials: Record<string, unknown> = { ...tokenInfo }
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra)
} catch (error: any) {
oauth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
@@ -3779,11 +3845,8 @@ const handleCookieAuth = async (sessionKey: string) => {
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
// Merge interceptWarmupRequests into credentials
const credentials: Record<string, unknown> = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
}
const credentials: Record<string, unknown> = { ...tokenInfo }
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
if (tempUnschedEnabled.value) {
credentials.temp_unschedulable_enabled = true
credentials.temp_unschedulable_rules = tempUnschedPayload

View File

@@ -65,8 +65,8 @@
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
</div>
<!-- Model Restriction Section (不适用于 Gemini Antigravity) -->
<div v-if="account.platform !== 'gemini' && account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<!-- Model Restriction Section (不适用于 Antigravity) -->
<div v-if="account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
<div
@@ -349,34 +349,6 @@
</div>
</div>
<!-- Gemini 模型说明 -->
<div v-if="account.platform === 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="rounded-lg bg-blue-50 p-4 dark:bg-blue-900/20">
<div class="flex items-start gap-3">
<svg
class="h-5 w-5 flex-shrink-0 text-blue-600 dark:text-blue-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
<div>
<p class="text-sm font-medium text-blue-800 dark:text-blue-300">
{{ t('admin.accounts.gemini.modelPassthrough') }}
</p>
<p class="mt-1 text-xs text-blue-700 dark:text-blue-400">
{{ t('admin.accounts.gemini.modelPassthroughDesc') }}
</p>
</div>
</div>
</div>
</div>
</div>
<!-- Upstream fields (only for upstream type) -->
@@ -641,9 +613,9 @@
</div>
</div>
<!-- Intercept Warmup Requests (Anthropic only) -->
<!-- Intercept Warmup Requests (Anthropic/Antigravity) -->
<div
v-if="account?.platform === 'anthropic'"
v-if="account?.platform === 'anthropic' || account?.platform === 'antigravity'"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
>
<div class="flex items-center justify-between">
@@ -1139,7 +1111,7 @@
<ConfirmDialog
:show="showMixedChannelWarning"
:title="t('admin.accounts.mixedChannelWarningTitle')"
:message="mixedChannelWarningDetails ? t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails) : ''"
:message="mixedChannelWarningMessageText"
:confirm-text="t('common.confirm')"
:cancel-text="t('common.cancel')"
:danger="true"
@@ -1154,7 +1126,7 @@ import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { useAuthStore } from '@/stores/auth'
import { adminAPI } from '@/api/admin'
import type { Account, Proxy, AdminGroup } from '@/types'
import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
@@ -1162,6 +1134,7 @@ import Icon from '@/components/icons/Icon.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import {
@@ -1233,10 +1206,13 @@ const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-mod
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
// Mixed channel warning dialog state
const showMixedChannelWarning = ref(false)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null)
const pendingUpdatePayload = ref<Record<string, unknown> | null>(null)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
null
)
const mixedChannelWarningRawMessage = ref('')
const mixedChannelWarningAction = ref<(() => Promise<void>) | null>(null)
const antigravityMixedChannelConfirmed = ref(false)
// Quota control state (Anthropic OAuth/SetupToken only)
const windowCostEnabled = ref(false)
@@ -1297,6 +1273,13 @@ const defaultBaseUrl = computed(() => {
return 'https://api.anthropic.com'
})
const mixedChannelWarningMessageText = computed(() => {
if (mixedChannelWarningDetails.value) {
return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value)
}
return mixedChannelWarningRawMessage.value
})
const form = reactive({
name: '',
notes: '',
@@ -1326,6 +1309,11 @@ watch(
() => props.account,
(newAccount) => {
if (newAccount) {
antigravityMixedChannelConfirmed.value = false
showMixedChannelWarning.value = false
mixedChannelWarningDetails.value = null
mixedChannelWarningRawMessage.value = ''
mixedChannelWarningAction.value = null
form.name = newAccount.name
form.notes = newAccount.notes || ''
form.proxy_id = newAccount.proxy_id
@@ -1725,18 +1713,123 @@ function toPositiveNumber(value: unknown) {
return Math.trunc(num)
}
const needsMixedChannelCheck = () => props.account?.platform === 'antigravity' || props.account?.platform === 'anthropic'
const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => {
const details = resp?.details
if (!details) {
return null
}
return {
groupName: details.group_name || 'Unknown',
currentPlatform: details.current_platform || 'Unknown',
otherPlatform: details.other_platform || 'Unknown'
}
}
const clearMixedChannelDialog = () => {
showMixedChannelWarning.value = false
mixedChannelWarningDetails.value = null
mixedChannelWarningRawMessage.value = ''
mixedChannelWarningAction.value = null
}
const openMixedChannelDialog = (opts: {
response?: CheckMixedChannelResponse
message?: string
onConfirm: () => Promise<void>
}) => {
mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response)
mixedChannelWarningRawMessage.value =
opts.message || opts.response?.message || t('admin.accounts.failedToUpdate')
mixedChannelWarningAction.value = opts.onConfirm
showMixedChannelWarning.value = true
}
const withAntigravityConfirmFlag = (payload: Record<string, unknown>) => {
if (needsMixedChannelCheck() && antigravityMixedChannelConfirmed.value) {
return {
...payload,
confirm_mixed_channel_risk: true
}
}
const cloned = { ...payload }
delete cloned.confirm_mixed_channel_risk
return cloned
}
const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise<void>): Promise<boolean> => {
if (!needsMixedChannelCheck()) {
return true
}
if (antigravityMixedChannelConfirmed.value) {
return true
}
if (!props.account) {
return false
}
try {
const result = await adminAPI.accounts.checkMixedChannelRisk({
platform: props.account.platform,
group_ids: form.group_ids,
account_id: props.account.id
})
if (!result.has_risk) {
return true
}
openMixedChannelDialog({
response: result,
onConfirm: async () => {
antigravityMixedChannelConfirmed.value = true
await onConfirm()
}
})
return false
} catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
return false
}
}
const formatDateTimeLocal = formatDateTimeLocalInput
const parseDateTimeLocal = parseDateTimeLocalInput
// Methods
const handleClose = () => {
antigravityMixedChannelConfirmed.value = false
clearMixedChannelDialog()
emit('close')
}
const submitUpdateAccount = async (accountID: number, updatePayload: Record<string, unknown>) => {
submitting.value = true
try {
const updatedAccount = await adminAPI.accounts.update(accountID, withAntigravityConfirmFlag(updatePayload))
appStore.showSuccess(t('admin.accounts.accountUpdated'))
emit('updated', updatedAccount)
handleClose()
} catch (error: any) {
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck()) {
openMixedChannelDialog({
message: error.response?.data?.message,
onConfirm: async () => {
antigravityMixedChannelConfirmed.value = true
await submitUpdateAccount(accountID, updatePayload)
}
})
return
}
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
} finally {
submitting.value = false
}
}
const handleSubmit = async () => {
if (!props.account) return
const accountID = props.account.id
submitting.value = true
const updatePayload: Record<string, unknown> = { ...form }
try {
// 后端期望 proxy_id: 0 表示清除代理,而不是 null
@@ -1768,7 +1861,6 @@ const handleSubmit = async () => {
newCredentials.api_key = currentCredentials.api_key
} else {
appStore.showError(t('admin.accounts.apiKeyIsRequired'))
submitting.value = false
return
}
@@ -1789,11 +1881,8 @@ const handleSubmit = async () => {
}
// Add intercept warmup requests setting
if (interceptWarmupRequests.value) {
newCredentials.intercept_warmup_requests = true
}
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
submitting.value = false
return
}
@@ -1808,8 +1897,10 @@ const handleSubmit = async () => {
newCredentials.api_key = editApiKey.value.trim()
}
// Add intercept warmup requests setting
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
submitting.value = false
return
}
@@ -1819,13 +1910,8 @@ const handleSubmit = async () => {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
if (interceptWarmupRequests.value) {
newCredentials.intercept_warmup_requests = true
} else {
delete newCredentials.intercept_warmup_requests
}
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
if (!applyTempUnschedConfig(newCredentials)) {
submitting.value = false
return
}
@@ -1955,52 +2041,36 @@ const handleSubmit = async () => {
updatePayload.extra = newExtra
}
const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload)
appStore.showSuccess(t('admin.accounts.accountUpdated'))
emit('updated', updatedAccount)
handleClose()
} catch (error: any) {
// Handle 409 mixed_channel_warning - show confirmation dialog
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') {
const details = error.response.data.details || {}
mixedChannelWarningDetails.value = {
groupName: details.group_name || 'Unknown',
currentPlatform: details.current_platform || 'Unknown',
otherPlatform: details.other_platform || 'Unknown'
}
pendingUpdatePayload.value = updatePayload
showMixedChannelWarning.value = true
} else {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
await submitUpdateAccount(accountID, updatePayload)
})
if (!canContinue) {
return
}
} finally {
submitting.value = false
await submitUpdateAccount(accountID, updatePayload)
} catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
}
}
// Handle mixed channel warning confirmation
const handleMixedChannelConfirm = async () => {
showMixedChannelWarning.value = false
if (pendingUpdatePayload.value && props.account) {
pendingUpdatePayload.value.confirm_mixed_channel_risk = true
submitting.value = true
try {
const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
appStore.showSuccess(t('admin.accounts.accountUpdated'))
emit('updated', updatedAccount)
handleClose()
} catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
} finally {
submitting.value = false
pendingUpdatePayload.value = null
}
const action = mixedChannelWarningAction.value
if (!action) {
clearMixedChannelDialog()
return
}
clearMixedChannelDialog()
submitting.value = true
try {
await action()
} finally {
submitting.value = false
}
}
const handleMixedChannelCancel = () => {
showMixedChannelWarning.value = false
pendingUpdatePayload.value = null
mixedChannelWarningDetails.value = null
clearMixedChannelDialog()
}
</script>

View File

@@ -0,0 +1,46 @@
import { describe, it, expect } from 'vitest'
import { applyInterceptWarmup } from '../credentialsBuilder'
describe('applyInterceptWarmup', () => {
it('create + enabled=true: should set intercept_warmup_requests to true', () => {
const creds: Record<string, unknown> = { access_token: 'tok' }
applyInterceptWarmup(creds, true, 'create')
expect(creds.intercept_warmup_requests).toBe(true)
})
it('create + enabled=false: should not add the field', () => {
const creds: Record<string, unknown> = { access_token: 'tok' }
applyInterceptWarmup(creds, false, 'create')
expect('intercept_warmup_requests' in creds).toBe(false)
})
it('edit + enabled=true: should set intercept_warmup_requests to true', () => {
const creds: Record<string, unknown> = { api_key: 'sk' }
applyInterceptWarmup(creds, true, 'edit')
expect(creds.intercept_warmup_requests).toBe(true)
})
it('edit + enabled=false + field exists: should delete the field', () => {
const creds: Record<string, unknown> = { api_key: 'sk', intercept_warmup_requests: true }
applyInterceptWarmup(creds, false, 'edit')
expect('intercept_warmup_requests' in creds).toBe(false)
})
it('edit + enabled=false + field absent: should not throw', () => {
const creds: Record<string, unknown> = { api_key: 'sk' }
applyInterceptWarmup(creds, false, 'edit')
expect('intercept_warmup_requests' in creds).toBe(false)
})
it('should not affect other fields', () => {
const creds: Record<string, unknown> = {
api_key: 'sk',
base_url: 'url',
intercept_warmup_requests: true
}
applyInterceptWarmup(creds, false, 'edit')
expect(creds.api_key).toBe('sk')
expect(creds.base_url).toBe('url')
expect('intercept_warmup_requests' in creds).toBe(false)
})
})

View File

@@ -0,0 +1,11 @@
export function applyInterceptWarmup(
credentials: Record<string, unknown>,
enabled: boolean,
mode: 'create' | 'edit'
): void {
if (enabled) {
credentials.intercept_warmup_requests = true
} else if (mode === 'edit') {
delete credentials.intercept_warmup_requests
}
}

View File

@@ -76,6 +76,7 @@ const antigravityModels = [
// Claude 4.5+ 系列
'claude-opus-4-6',
'claude-opus-4-5-thinking',
'claude-sonnet-4-6',
'claude-sonnet-4-5',
'claude-sonnet-4-5-thinking',
// Gemini 2.5 系列
@@ -88,6 +89,9 @@ const antigravityModels = [
'gemini-3-pro-high',
'gemini-3-pro-low',
'gemini-3-pro-image',
// Gemini 3.1 系列
'gemini-3.1-pro-high',
'gemini-3.1-pro-low',
// 其他
'gpt-oss-120b-medium',
'tab_flash_lite_preview'
@@ -287,14 +291,23 @@ const antigravityPresetMappings = [
{ label: 'Sonnet→Sonnet', from: 'claude-sonnet-*', to: 'claude-sonnet-4-5', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
{ label: 'Opus→Opus', from: 'claude-opus-*', to: 'claude-opus-4-6-thinking', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'Haiku→Sonnet', from: 'claude-haiku-*', to: 'claude-sonnet-4-5', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
{ label: 'Sonnet4→4.6', from: 'claude-sonnet-4-20250514', to: 'claude-sonnet-4-6', color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' },
{ label: 'Sonnet4.5→4.6', from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Sonnet3.5→4.6', from: 'claude-3-5-sonnet-20241022', to: 'claude-sonnet-4-6', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' },
{ label: 'Opus4.5→4.6', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-6-thinking', color: 'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400' },
// Gemini 3→3.1 映射
{ label: '3-Pro-Preview→3.1-Pro-High', from: 'gemini-3-pro-preview', to: 'gemini-3.1-pro-high', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' },
{ label: '3-Pro-High→3.1-Pro-High', from: 'gemini-3-pro-high', to: 'gemini-3.1-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
{ label: '3-Pro-Low→3.1-Pro-Low', from: 'gemini-3-pro-low', to: 'gemini-3.1-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
{ label: '3.1-Pro-High透传', from: 'gemini-3.1-pro-high', to: 'gemini-3.1-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
{ label: '3.1-Pro-Low透传', from: 'gemini-3.1-pro-low', to: 'gemini-3.1-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
// Gemini 通配符映射
{ label: 'Gemini 3→Flash', from: 'gemini-3*', to: 'gemini-3-flash', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
{ label: 'Gemini 2.5→Flash', from: 'gemini-2.5*', to: 'gemini-2.5-flash', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
{ label: '3-Flash透传', from: 'gemini-3-flash', to: 'gemini-3-flash', color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400' },
{ label: '2.5-Flash-Lite透传', from: 'gemini-2.5-flash-lite', to: 'gemini-2.5-flash-lite', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
// 精确映射
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
]

View File

@@ -2047,7 +2047,7 @@ export default {
gemini3Pro: 'G3P',
gemini3Flash: 'G3F',
gemini3Image: 'G3I',
claude45: 'C4.5'
claude: 'Claude'
},
tier: {
free: 'Free',

View File

@@ -1583,7 +1583,7 @@ export default {
gemini3Pro: 'G3P',
gemini3Flash: 'G3F',
gemini3Image: 'G3I',
claude45: 'C4.5'
claude: 'Claude'
},
tier: {
free: 'Free',

View File

@@ -581,6 +581,7 @@ export interface GeminiCredentials {
token_type?: string
scope?: string
expires_at?: string
model_mapping?: Record<string, string>
}
export interface TempUnschedulableRule {
@@ -766,6 +767,26 @@ export interface UpdateAccountRequest {
confirm_mixed_channel_risk?: boolean
}
export interface CheckMixedChannelRequest {
platform: AccountPlatform
group_ids: number[]
account_id?: number
}
export interface MixedChannelWarningDetails {
group_id: number
group_name: string
current_platform: string
other_platform: string
}
export interface CheckMixedChannelResponse {
has_risk: boolean
error?: string
message?: string
details?: MixedChannelWarningDetails
}
export interface CreateProxyRequest {
name: string
protocol: ProxyProtocol

View File

@@ -1,18 +1,13 @@
import { defineConfig } from 'vitest/config'
import vue from '@vitejs/plugin-vue'
import { resolve } from 'path'
export default defineConfig({
plugins: [vue()],
resolve: {
alias: {
'@': resolve(__dirname, 'src'),
'vue-i18n': 'vue-i18n/dist/vue-i18n.runtime.esm-bundler.js'
}
},
define: {
__INTLIFY_JIT_COMPILATION__: true
},
test: {
globals: true,
environment: 'jsdom',
@@ -37,8 +32,6 @@ export default defineConfig({
lines: 80
}
}
},
setupFiles: ['./src/__tests__/setup.ts'],
testTimeout: 10000
}
}
})