mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-05 16:00:21 +08:00
Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ded9b6c14e | ||
|
|
609abbbd7c | ||
|
|
1b4e504fad | ||
|
|
0a3a445828 | ||
|
|
c7e18bd5be | ||
|
|
083d202fe4 | ||
|
|
8365a8328b | ||
|
|
58f21e4b3a | ||
|
|
5bd7408b2f | ||
|
|
c671e8dd1d | ||
|
|
a3aed3c4c3 | ||
|
|
c008649584 | ||
|
|
516f8f287c | ||
|
|
66148690c6 | ||
|
|
cadd7f546f | ||
|
|
a3ff317f1c | ||
|
|
d8d4b0c0c7 | ||
|
|
d616f8c854 | ||
|
|
b6fa8b8eec | ||
|
|
36d2e6999b | ||
|
|
076c00063d | ||
|
|
ea8104c6a2 | ||
|
|
ca3e9336e1 | ||
|
|
f92ab48166 | ||
|
|
c10267ce2b | ||
|
|
9bd6a62ab3 | ||
|
|
0dbea6ca58 | ||
|
|
6523b23221 | ||
|
|
29c406dda0 | ||
|
|
483c8f246d | ||
|
|
645f283108 | ||
|
|
da6fd45000 | ||
|
|
fb3ef5f388 | ||
|
|
86bc76e352 | ||
|
|
644058174e | ||
|
|
4573868c08 | ||
|
|
09166a52f8 | ||
|
|
aaac1aaca9 | ||
|
|
59898c16c6 | ||
|
|
0dacdf480b | ||
|
|
fdf9f68298 |
25
DEV_GUIDE.md
25
DEV_GUIDE.md
@@ -209,7 +209,30 @@ git add ent/ # 生成的文件也要提交
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### 坑 10:PR 提交前检查清单
|
### 坑 10:前端测试看似正常,但后端调用失败(模型映射被批量误改)
|
||||||
|
|
||||||
|
**典型现象**:
|
||||||
|
- 前端按钮点测看起来正常;
|
||||||
|
- 实际通过 API/客户端调用时返回 `Service temporarily unavailable` 或提示无可用账号;
|
||||||
|
- 常见于 OpenAI 账号(例如 Codex 模型)在批量修改后突然不可用。
|
||||||
|
|
||||||
|
**根因**:
|
||||||
|
- OpenAI 账号编辑页默认不显式展示映射规则,容易让人误以为“没映射也没关系”;
|
||||||
|
- 但在**批量修改同时选中不同平台账号**(OpenAI + Antigravity/Gemini)时,模型白名单/映射可能被跨平台策略覆盖;
|
||||||
|
- 结果是 OpenAI 账号的关键模型映射丢失或被改坏,后端选不到可用账号。
|
||||||
|
|
||||||
|
**修复方案(按优先级)**:
|
||||||
|
1. **快速修复(推荐)**:在批量修改中补回正确的透传映射(例如 `gpt-5.3-codex -> gpt-5.3-codex-spark`)。
|
||||||
|
2. **彻底重建**:删除并重新添加全部相关账号(最稳但成本高)。
|
||||||
|
|
||||||
|
**关键经验**:
|
||||||
|
- 如果某模型已被软件内置默认映射覆盖,通常不需要额外再加透传;
|
||||||
|
- 但当上游模型更新快于本仓库默认映射时,**手动批量添加透传映射**是最简单、最低风险的临时兜底方案;
|
||||||
|
- 批量操作前尽量按平台分组,不要混选不同平台账号。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 坑 11:PR 提交前检查清单
|
||||||
|
|
||||||
提交 PR 前务必本地验证:
|
提交 PR 前务必本地验证:
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ require (
|
|||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
github.com/pquerna/otp v1.5.0
|
github.com/pquerna/otp v1.5.0
|
||||||
github.com/redis/go-redis/v9 v9.17.2
|
github.com/redis/go-redis/v9 v9.17.2
|
||||||
github.com/refraction-networking/utls v1.8.1
|
github.com/refraction-networking/utls v1.8.2
|
||||||
github.com/robfig/cron/v3 v3.0.1
|
github.com/robfig/cron/v3 v3.0.1
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6
|
github.com/shirou/gopsutil/v4 v4.25.6
|
||||||
github.com/spf13/viper v1.18.2
|
github.com/spf13/viper v1.18.2
|
||||||
@@ -79,7 +79,6 @@ require (
|
|||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/google/go-cmp v0.7.0 // indirect
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
github.com/google/go-querystring v1.1.0 // indirect
|
github.com/google/go-querystring v1.1.0 // indirect
|
||||||
github.com/google/subcommands v1.2.0 // indirect
|
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||||
@@ -148,7 +147,6 @@ require (
|
|||||||
golang.org/x/mod v0.31.0 // indirect
|
golang.org/x/mod v0.31.0 // indirect
|
||||||
golang.org/x/sys v0.40.0 // indirect
|
golang.org/x/sys v0.40.0 // indirect
|
||||||
golang.org/x/text v0.33.0 // indirect
|
golang.org/x/text v0.33.0 // indirect
|
||||||
golang.org/x/tools v0.40.0 // indirect
|
|
||||||
google.golang.org/grpc v1.75.1 // indirect
|
google.golang.org/grpc v1.75.1 // indirect
|
||||||
google.golang.org/protobuf v1.36.10 // indirect
|
google.golang.org/protobuf v1.36.10 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
|
|||||||
@@ -120,8 +120,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
|||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
|
||||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||||
@@ -176,8 +174,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
|
||||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
@@ -211,8 +207,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
|||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
|
||||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
@@ -238,12 +232,10 @@ github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI1
|
|||||||
github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
|
github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
|
||||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
|
||||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
@@ -266,8 +258,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
|||||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
|
||||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||||
|
|||||||
@@ -1158,6 +1158,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.force_codex_cli", false)
|
viper.SetDefault("gateway.force_codex_cli", false)
|
||||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||||
|
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
||||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
|
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
|
||||||
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
|
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
|
||||||
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
|
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||||
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||||
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||||
// Claude 详细版本 ID 映射
|
// Claude 详细版本 ID 映射
|
||||||
@@ -89,16 +90,18 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||||
// Gemini 3 白名单
|
// Gemini 3 白名单
|
||||||
"gemini-3-flash": "gemini-3-flash",
|
"gemini-3-flash": "gemini-3-flash",
|
||||||
"gemini-3-pro-high": "gemini-3.1-pro-high",
|
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||||
"gemini-3-pro-low": "gemini-3.1-pro-low",
|
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||||
"gemini-3-pro-image": "gemini-3-pro-image",
|
"gemini-3-pro-image": "gemini-3-pro-image",
|
||||||
// Gemini 3.1 透传
|
|
||||||
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
|
|
||||||
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
|
|
||||||
// Gemini 3 preview 映射
|
// Gemini 3 preview 映射
|
||||||
"gemini-3-flash-preview": "gemini-3-flash",
|
"gemini-3-flash-preview": "gemini-3-flash",
|
||||||
"gemini-3-pro-preview": "gemini-3.1-pro-high",
|
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||||
|
// Gemini 3.1 白名单
|
||||||
|
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
|
||||||
|
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
|
||||||
|
// Gemini 3.1 preview 映射
|
||||||
|
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
|
||||||
// 其他官方模型
|
// 其他官方模型
|
||||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||||
|
|||||||
@@ -139,6 +139,13 @@ type BulkUpdateAccountsRequest struct {
|
|||||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckMixedChannelRequest represents check mixed channel risk request
|
||||||
|
type CheckMixedChannelRequest struct {
|
||||||
|
Platform string `json:"platform" binding:"required"`
|
||||||
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
|
AccountID *int64 `json:"account_id"`
|
||||||
|
}
|
||||||
|
|
||||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||||
type AccountWithConcurrency struct {
|
type AccountWithConcurrency struct {
|
||||||
*dto.Account
|
*dto.Account
|
||||||
@@ -389,6 +396,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
|
|||||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
|
||||||
|
// POST /api/v1/admin/accounts/check-mixed-channel
|
||||||
|
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
|
||||||
|
var req CheckMixedChannelRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.GroupIDs) == 0 {
|
||||||
|
response.Success(c, gin.H{"has_risk": false})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID := int64(0)
|
||||||
|
if req.AccountID != nil {
|
||||||
|
accountID = *req.AccountID
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
|
||||||
|
if err != nil {
|
||||||
|
var mixedErr *service.MixedChannelError
|
||||||
|
if errors.As(err, &mixedErr) {
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"has_risk": true,
|
||||||
|
"error": "mixed_channel_warning",
|
||||||
|
"message": mixedErr.Error(),
|
||||||
|
"details": gin.H{
|
||||||
|
"group_id": mixedErr.GroupID,
|
||||||
|
"group_name": mixedErr.GroupName,
|
||||||
|
"current_platform": mixedErr.CurrentPlatform,
|
||||||
|
"other_platform": mixedErr.OtherPlatform,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"has_risk": false})
|
||||||
|
}
|
||||||
|
|
||||||
// Create handles creating a new account
|
// Create handles creating a new account
|
||||||
// POST /api/v1/admin/accounts
|
// POST /api/v1/admin/accounts
|
||||||
func (h *AccountHandler) Create(c *gin.Context) {
|
func (h *AccountHandler) Create(c *gin.Context) {
|
||||||
@@ -431,17 +482,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
|||||||
// 检查是否为混合渠道错误
|
// 检查是否为混合渠道错误
|
||||||
var mixedErr *service.MixedChannelError
|
var mixedErr *service.MixedChannelError
|
||||||
if errors.As(err, &mixedErr) {
|
if errors.As(err, &mixedErr) {
|
||||||
// 返回特殊错误码要求确认
|
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
|
||||||
c.JSON(409, gin.H{
|
c.JSON(409, gin.H{
|
||||||
"error": "mixed_channel_warning",
|
"error": "mixed_channel_warning",
|
||||||
"message": mixedErr.Error(),
|
"message": mixedErr.Error(),
|
||||||
"details": gin.H{
|
|
||||||
"group_id": mixedErr.GroupID,
|
|
||||||
"group_name": mixedErr.GroupName,
|
|
||||||
"current_platform": mixedErr.CurrentPlatform,
|
|
||||||
"other_platform": mixedErr.OtherPlatform,
|
|
||||||
},
|
|
||||||
"require_confirmation": true,
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -501,17 +545,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
|||||||
// 检查是否为混合渠道错误
|
// 检查是否为混合渠道错误
|
||||||
var mixedErr *service.MixedChannelError
|
var mixedErr *service.MixedChannelError
|
||||||
if errors.As(err, &mixedErr) {
|
if errors.As(err, &mixedErr) {
|
||||||
// 返回特殊错误码要求确认
|
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
|
||||||
c.JSON(409, gin.H{
|
c.JSON(409, gin.H{
|
||||||
"error": "mixed_channel_warning",
|
"error": "mixed_channel_warning",
|
||||||
"message": mixedErr.Error(),
|
"message": mixedErr.Error(),
|
||||||
"details": gin.H{
|
|
||||||
"group_id": mixedErr.GroupID,
|
|
||||||
"group_name": mixedErr.GroupName,
|
|
||||||
"current_platform": mixedErr.CurrentPlatform,
|
|
||||||
"other_platform": mixedErr.OtherPlatform,
|
|
||||||
},
|
|
||||||
"require_confirmation": true,
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,147 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
|
||||||
|
router.POST("/api/v1/admin/accounts", accountHandler.Create)
|
||||||
|
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
|
||||||
|
adminSvc := newStubAdminService()
|
||||||
|
router := setupAccountMixedChannelRouter(adminSvc)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"platform": "antigravity",
|
||||||
|
"group_ids": []int64{27},
|
||||||
|
})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
var resp map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, float64(0), resp["code"])
|
||||||
|
data, ok := resp["data"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, false, data["has_risk"])
|
||||||
|
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
|
||||||
|
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
|
||||||
|
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
|
||||||
|
adminSvc := newStubAdminService()
|
||||||
|
adminSvc.checkMixedErr = &service.MixedChannelError{
|
||||||
|
GroupID: 27,
|
||||||
|
GroupName: "claude-max",
|
||||||
|
CurrentPlatform: "Antigravity",
|
||||||
|
OtherPlatform: "Anthropic",
|
||||||
|
}
|
||||||
|
router := setupAccountMixedChannelRouter(adminSvc)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"platform": "antigravity",
|
||||||
|
"group_ids": []int64{27},
|
||||||
|
"account_id": 99,
|
||||||
|
})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
var resp map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, float64(0), resp["code"])
|
||||||
|
data, ok := resp["data"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, true, data["has_risk"])
|
||||||
|
require.Equal(t, "mixed_channel_warning", data["error"])
|
||||||
|
details, ok := data["details"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, float64(27), details["group_id"])
|
||||||
|
require.Equal(t, "claude-max", details["group_name"])
|
||||||
|
require.Equal(t, "Antigravity", details["current_platform"])
|
||||||
|
require.Equal(t, "Anthropic", details["other_platform"])
|
||||||
|
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
|
||||||
|
adminSvc := newStubAdminService()
|
||||||
|
adminSvc.createAccountErr = &service.MixedChannelError{
|
||||||
|
GroupID: 27,
|
||||||
|
GroupName: "claude-max",
|
||||||
|
CurrentPlatform: "Antigravity",
|
||||||
|
OtherPlatform: "Anthropic",
|
||||||
|
}
|
||||||
|
router := setupAccountMixedChannelRouter(adminSvc)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"name": "ag-oauth-1",
|
||||||
|
"platform": "antigravity",
|
||||||
|
"type": "oauth",
|
||||||
|
"credentials": map[string]any{"refresh_token": "rt"},
|
||||||
|
"group_ids": []int64{27},
|
||||||
|
})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusConflict, rec.Code)
|
||||||
|
var resp map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||||
|
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||||
|
_, hasDetails := resp["details"]
|
||||||
|
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||||
|
require.False(t, hasDetails)
|
||||||
|
require.False(t, hasRequireConfirmation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
|
||||||
|
adminSvc := newStubAdminService()
|
||||||
|
adminSvc.updateAccountErr = &service.MixedChannelError{
|
||||||
|
GroupID: 27,
|
||||||
|
GroupName: "claude-max",
|
||||||
|
CurrentPlatform: "Antigravity",
|
||||||
|
OtherPlatform: "Anthropic",
|
||||||
|
}
|
||||||
|
router := setupAccountMixedChannelRouter(adminSvc)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"group_ids": []int64{27},
|
||||||
|
})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusConflict, rec.Code)
|
||||||
|
var resp map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||||
|
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||||
|
_, hasDetails := resp["details"]
|
||||||
|
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||||
|
require.False(t, hasDetails)
|
||||||
|
require.False(t, hasRequireConfirmation)
|
||||||
|
}
|
||||||
@@ -10,19 +10,27 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type stubAdminService struct {
|
type stubAdminService struct {
|
||||||
users []service.User
|
users []service.User
|
||||||
apiKeys []service.APIKey
|
apiKeys []service.APIKey
|
||||||
groups []service.Group
|
groups []service.Group
|
||||||
accounts []service.Account
|
accounts []service.Account
|
||||||
proxies []service.Proxy
|
proxies []service.Proxy
|
||||||
proxyCounts []service.ProxyWithAccountCount
|
proxyCounts []service.ProxyWithAccountCount
|
||||||
redeems []service.RedeemCode
|
redeems []service.RedeemCode
|
||||||
createdAccounts []*service.CreateAccountInput
|
createdAccounts []*service.CreateAccountInput
|
||||||
createdProxies []*service.CreateProxyInput
|
createdProxies []*service.CreateProxyInput
|
||||||
updatedProxyIDs []int64
|
updatedProxyIDs []int64
|
||||||
updatedProxies []*service.UpdateProxyInput
|
updatedProxies []*service.UpdateProxyInput
|
||||||
testedProxyIDs []int64
|
testedProxyIDs []int64
|
||||||
mu sync.Mutex
|
createAccountErr error
|
||||||
|
updateAccountErr error
|
||||||
|
checkMixedErr error
|
||||||
|
lastMixedCheck struct {
|
||||||
|
accountID int64
|
||||||
|
platform string
|
||||||
|
groupIDs []int64
|
||||||
|
}
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newStubAdminService() *stubAdminService {
|
func newStubAdminService() *stubAdminService {
|
||||||
@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.createdAccounts = append(s.createdAccounts, input)
|
s.createdAccounts = append(s.createdAccounts, input)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
if s.createAccountErr != nil {
|
||||||
|
return nil, s.createAccountErr
|
||||||
|
}
|
||||||
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||||
|
if s.updateAccountErr != nil {
|
||||||
|
return nil, s.updateAccountErr
|
||||||
|
}
|
||||||
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
|
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
|
|||||||
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
|
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||||
|
s.lastMixedCheck.accountID = currentAccountID
|
||||||
|
s.lastMixedCheck.platform = currentAccountPlatform
|
||||||
|
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
|
||||||
|
return s.checkMixedErr
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
||||||
search = strings.TrimSpace(strings.ToLower(search))
|
search = strings.TrimSpace(strings.ToLower(search))
|
||||||
filtered := make([]service.Proxy, 0, len(s.proxies))
|
filtered := make([]service.Proxy, 0, len(s.proxies))
|
||||||
|
|||||||
@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
// Treat missing/invalid OAuth client configuration as a user/config error.
|
// Treat missing/invalid OAuth client configuration as a user/config error.
|
||||||
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
|
if strings.Contains(msg, "OAuth client not configured") ||
|
||||||
|
strings.Contains(msg, "requires your own OAuth Client") ||
|
||||||
|
strings.Contains(msg, "requires a custom OAuth Client") ||
|
||||||
|
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
|
||||||
|
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
|
||||||
response.BadRequest(c, "Failed to generate auth URL: "+msg)
|
response.BadRequest(c, "Failed to generate auth URL: "+msg)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
160
backend/internal/handler/failover_loop.go
Normal file
160
backend/internal/handler/failover_loop.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
|
||||||
|
// GatewayService 隐式实现此接口。
|
||||||
|
type TempUnscheduler interface {
|
||||||
|
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FailoverAction 表示 failover 错误处理后的下一步动作
|
||||||
|
type FailoverAction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue)
|
||||||
|
FailoverContinue FailoverAction = iota
|
||||||
|
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
|
||||||
|
FailoverExhausted
|
||||||
|
// FailoverCanceled context 已取消(调用方应直接 return)
|
||||||
|
FailoverCanceled
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||||
|
maxSameAccountRetries = 2
|
||||||
|
// sameAccountRetryDelay 同账号重试间隔
|
||||||
|
sameAccountRetryDelay = 500 * time.Millisecond
|
||||||
|
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||||
|
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s),
|
||||||
|
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||||
|
singleAccountBackoffDelay = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// FailoverState 跨循环迭代共享的 failover 状态
|
||||||
|
type FailoverState struct {
|
||||||
|
SwitchCount int
|
||||||
|
MaxSwitches int
|
||||||
|
FailedAccountIDs map[int64]struct{}
|
||||||
|
SameAccountRetryCount map[int64]int
|
||||||
|
LastFailoverErr *service.UpstreamFailoverError
|
||||||
|
ForceCacheBilling bool
|
||||||
|
hasBoundSession bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFailoverState 创建 failover 状态
|
||||||
|
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
|
||||||
|
return &FailoverState{
|
||||||
|
MaxSwitches: maxSwitches,
|
||||||
|
FailedAccountIDs: make(map[int64]struct{}),
|
||||||
|
SameAccountRetryCount: make(map[int64]int),
|
||||||
|
hasBoundSession: hasBoundSession,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。
|
||||||
|
// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
|
||||||
|
func (s *FailoverState) HandleFailoverError(
|
||||||
|
ctx context.Context,
|
||||||
|
gatewayService TempUnscheduler,
|
||||||
|
accountID int64,
|
||||||
|
platform string,
|
||||||
|
failoverErr *service.UpstreamFailoverError,
|
||||||
|
) FailoverAction {
|
||||||
|
s.LastFailoverErr = failoverErr
|
||||||
|
|
||||||
|
// 缓存计费判断
|
||||||
|
if needForceCacheBilling(s.hasBoundSession, failoverErr) {
|
||||||
|
s.ForceCacheBilling = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||||
|
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
|
||||||
|
s.SameAccountRetryCount[accountID]++
|
||||||
|
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||||
|
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
|
||||||
|
if !sleepWithContext(ctx, sameAccountRetryDelay) {
|
||||||
|
return FailoverCanceled
|
||||||
|
}
|
||||||
|
return FailoverContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同账号重试用尽,执行临时封禁
|
||||||
|
if failoverErr.RetryableOnSameAccount {
|
||||||
|
gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加入失败列表
|
||||||
|
s.FailedAccountIDs[accountID] = struct{}{}
|
||||||
|
|
||||||
|
// 检查是否耗尽
|
||||||
|
if s.SwitchCount >= s.MaxSwitches {
|
||||||
|
return FailoverExhausted
|
||||||
|
}
|
||||||
|
|
||||||
|
// 递增切换计数
|
||||||
|
s.SwitchCount++
|
||||||
|
log.Printf("Account %d: upstream error %d, switching account %d/%d",
|
||||||
|
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
|
||||||
|
|
||||||
|
// Antigravity 平台换号线性递增延时
|
||||||
|
if platform == service.PlatformAntigravity {
|
||||||
|
delay := time.Duration(s.SwitchCount-1) * time.Second
|
||||||
|
if !sleepWithContext(ctx, delay) {
|
||||||
|
return FailoverCanceled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return FailoverContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
|
||||||
|
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
|
||||||
|
// 清除排除列表、等待退避后重新选号。
|
||||||
|
//
|
||||||
|
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
|
||||||
|
// 返回 FailoverExhausted 时,调用方应返回错误响应。
|
||||||
|
// 返回 FailoverCanceled 时,调用方应直接 return。
|
||||||
|
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
|
||||||
|
if s.LastFailoverErr != nil &&
|
||||||
|
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
|
||||||
|
s.SwitchCount <= s.MaxSwitches {
|
||||||
|
|
||||||
|
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
|
||||||
|
singleAccountBackoffDelay, s.SwitchCount)
|
||||||
|
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
|
||||||
|
return FailoverCanceled
|
||||||
|
}
|
||||||
|
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
|
||||||
|
s.SwitchCount, s.MaxSwitches)
|
||||||
|
s.FailedAccountIDs = make(map[int64]struct{})
|
||||||
|
return FailoverContinue
|
||||||
|
}
|
||||||
|
return FailoverExhausted
|
||||||
|
}
|
||||||
|
|
||||||
|
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
|
||||||
|
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
|
||||||
|
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||||
|
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
|
||||||
|
func sleepWithContext(ctx context.Context, d time.Duration) bool {
|
||||||
|
if d <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-time.After(d):
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
732
backend/internal/handler/failover_loop_test.go
Normal file
732
backend/internal/handler/failover_loop_test.go
Normal file
@@ -0,0 +1,732 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mock
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
|
||||||
|
type mockTempUnscheduler struct {
|
||||||
|
calls []tempUnscheduleCall
|
||||||
|
}
|
||||||
|
|
||||||
|
type tempUnscheduleCall struct {
|
||||||
|
accountID int64
|
||||||
|
failoverErr *service.UpstreamFailoverError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) {
|
||||||
|
m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helper
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError {
|
||||||
|
return &service.UpstreamFailoverError{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
RetryableOnSameAccount: retryable,
|
||||||
|
ForceCacheBilling: forceBilling,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// NewFailoverState 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNewFailoverState(t *testing.T) {
|
||||||
|
t.Run("初始化字段正确", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(5, true)
|
||||||
|
require.Equal(t, 5, fs.MaxSwitches)
|
||||||
|
require.Equal(t, 0, fs.SwitchCount)
|
||||||
|
require.NotNil(t, fs.FailedAccountIDs)
|
||||||
|
require.Empty(t, fs.FailedAccountIDs)
|
||||||
|
require.NotNil(t, fs.SameAccountRetryCount)
|
||||||
|
require.Empty(t, fs.SameAccountRetryCount)
|
||||||
|
require.Nil(t, fs.LastFailoverErr)
|
||||||
|
require.False(t, fs.ForceCacheBilling)
|
||||||
|
require.True(t, fs.hasBoundSession)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("无绑定会话", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
require.Equal(t, 3, fs.MaxSwitches)
|
||||||
|
require.False(t, fs.hasBoundSession)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("零最大切换次数", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(0, false)
|
||||||
|
require.Equal(t, 0, fs.MaxSwitches)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// sleepWithContext 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSleepWithContext(t *testing.T) {
|
||||||
|
t.Run("零时长立即返回true", func(t *testing.T) {
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(context.Background(), 0)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("负时长立即返回true", func(t *testing.T) {
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(context.Background(), -1*time.Second)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("正常等待后返回true", func(t *testing.T) {
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(context.Background(), 50*time.Millisecond)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond)
|
||||||
|
require.Less(t, elapsed, 500*time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("已取消context立即返回false", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(ctx, 5*time.Second)
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("等待期间context取消返回false", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(ctx, 5*time.Second)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Less(t, elapsed, 500*time.Millisecond)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 基本切换流程
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_BasicSwitch(t *testing.T) {
|
||||||
|
t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
require.Equal(t, err, fs.LastFailoverErr)
|
||||||
|
require.False(t, fs.ForceCacheBilling)
|
||||||
|
require.Empty(t, mock.calls, "不应调用 TempUnschedule")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) {
|
||||||
|
// switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) {
|
||||||
|
// switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.SwitchCount = 1 // 模拟已切换一次
|
||||||
|
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 2, fs.SwitchCount)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s")
|
||||||
|
require.Less(t, elapsed, 3*time.Second)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("连续切换直到耗尽", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(2, false)
|
||||||
|
|
||||||
|
// 第一次切换:0→1
|
||||||
|
err1 := newTestFailoverErr(500, false, false)
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
|
||||||
|
// 第二次切换:1→2
|
||||||
|
err2 := newTestFailoverErr(502, false, false)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 2, fs.SwitchCount)
|
||||||
|
|
||||||
|
// 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2)
|
||||||
|
err3 := newTestFailoverErr(503, false, false)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增")
|
||||||
|
|
||||||
|
// 验证失败账号列表
|
||||||
|
require.Len(t, fs.FailedAccountIDs, 3)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(300))
|
||||||
|
|
||||||
|
// LastFailoverErr 应为最后一次的错误
|
||||||
|
require.Equal(t, err3, fs.LastFailoverErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(0, false)
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Equal(t, 0, fs.SwitchCount)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_CacheBilling(t *testing.T) {
|
||||||
|
t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.True(t, fs.ForceCacheBilling)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.True(t, fs.ForceCacheBilling)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("两者均为false时不设置", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.False(t, fs.ForceCacheBilling)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("一旦设置不会被后续错误重置", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
// 第一次:ForceCacheBilling=true → 设置
|
||||||
|
err1 := newTestFailoverErr(500, false, true)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||||
|
require.True(t, fs.ForceCacheBilling)
|
||||||
|
|
||||||
|
// 第二次:ForceCacheBilling=false → 仍然保持 true
|
||||||
|
err2 := newTestFailoverErr(502, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||||
|
require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||||
|
t.Run("第一次重试返回FailoverContinue", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||||
|
require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数")
|
||||||
|
require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表")
|
||||||
|
require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule")
|
||||||
|
// 验证等待了 sameAccountRetryDelay (500ms)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||||
|
require.Less(t, elapsed, 2*time.Second)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
// 第一次
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||||
|
|
||||||
|
// 第二次
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||||
|
|
||||||
|
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
// 第一次、第二次重试
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||||
|
|
||||||
|
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
|
||||||
|
// 验证 TempUnschedule 被调用
|
||||||
|
require.Len(t, mock.calls, 1)
|
||||||
|
require.Equal(t, int64(100), mock.calls[0].accountID)
|
||||||
|
require.Equal(t, err, mock.calls[0].failoverErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("不同账号独立跟踪重试次数", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(5, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
// 账号 100 第一次重试
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||||
|
|
||||||
|
// 账号 200 第一次重试(独立计数)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[200])
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(5, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
// 耗尽账号 100 的重试
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
// 第三次: 重试耗尽 → 切换
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
|
||||||
|
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — TempUnschedule 调用验证
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||||
|
t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Empty(t, mock.calls)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(502, true, false)
|
||||||
|
|
||||||
|
// 耗尽重试
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||||
|
|
||||||
|
require.Len(t, mock.calls, 1)
|
||||||
|
require.Equal(t, int64(42), mock.calls[0].accountID)
|
||||||
|
require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode)
|
||||||
|
require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — Context 取消
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_ContextCanceled(t *testing.T) {
|
||||||
|
t.Run("同账号重试sleep期间context取消", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // 立即取消
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(ctx, mock, 100, "openai", err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverCanceled, action)
|
||||||
|
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||||
|
// 重试计数仍应递增
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Antigravity延迟期间context取消", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // 立即取消
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverCanceled, action)
|
||||||
|
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — FailedAccountIDs 跟踪
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_FailedAccountIDs(t *testing.T) {
|
||||||
|
t.Run("切换时添加到失败列表", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false))
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||||
|
require.Len(t, fs.FailedAccountIDs, 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("耗尽时也添加到失败列表", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(0, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false))
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.NotContains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("同一账号多次切换不重复添加", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(5, false)
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||||
|
require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — LastFailoverErr 更新
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_LastFailoverErr(t *testing.T) {
|
||||||
|
t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
err1 := newTestFailoverErr(500, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||||
|
require.Equal(t, err1, fs.LastFailoverErr)
|
||||||
|
|
||||||
|
err2 := newTestFailoverErr(502, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||||
|
require.Equal(t, err2, fs.LastFailoverErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, err, fs.LastFailoverErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 综合集成场景
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||||
|
t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||||
|
|
||||||
|
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||||
|
retryErr := newTestFailoverErr(400, true, false)
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||||
|
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
|
||||||
|
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
require.Len(t, mock.calls, 1)
|
||||||
|
|
||||||
|
// 3. 账号 200 遇到不可重试错误 → 直接切换
|
||||||
|
switchErr := newTestFailoverErr(500, false, false)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 2, fs.SwitchCount)
|
||||||
|
|
||||||
|
// 4. 账号 300 遇到不可重试错误 → 再切换
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 3, fs.SwitchCount)
|
||||||
|
|
||||||
|
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr)
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
|
||||||
|
// 最终状态验证
|
||||||
|
require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增")
|
||||||
|
require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中")
|
||||||
|
require.True(t, fs.ForceCacheBilling)
|
||||||
|
require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("模拟Antigravity平台完整流程", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(2, false)
|
||||||
|
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
// 第一次切换:delay = 0s
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0")
|
||||||
|
|
||||||
|
// 第二次切换:delay = 1s
|
||||||
|
start = time.Now()
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||||
|
elapsed = time.Since(start)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s")
|
||||||
|
|
||||||
|
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
|
||||||
|
start = time.Now()
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err)
|
||||||
|
elapsed = time.Since(start)
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false) // hasBoundSession=false
|
||||||
|
|
||||||
|
// 第一次:ForceCacheBilling=false
|
||||||
|
err1 := newTestFailoverErr(500, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||||
|
require.False(t, fs.ForceCacheBilling)
|
||||||
|
|
||||||
|
// 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换)
|
||||||
|
err2 := newTestFailoverErr(500, false, true)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||||
|
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling")
|
||||||
|
|
||||||
|
// 第三次:ForceCacheBilling=false,但状态仍保持 true
|
||||||
|
err3 := newTestFailoverErr(500, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||||
|
require.True(t, fs.ForceCacheBilling, "不应重置")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 边界条件
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_EdgeCases(t *testing.T) {
|
||||||
|
t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(0, false, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AccountID为0也能正常跟踪", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, true, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[0])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("负AccountID也能正常跟踪", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, true, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[-1])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.SwitchCount = 1
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "", err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleSelectionExhausted 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleSelectionExhausted(t *testing.T) {
|
||||||
|
t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
// LastFailoverErr 为 nil
|
||||||
|
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非503错误返回Exhausted", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||||
|
fs.FailedAccountIDs[100] = struct{}{}
|
||||||
|
fs.SwitchCount = 1
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表")
|
||||||
|
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s")
|
||||||
|
require.Less(t, elapsed, 5*time.Second)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(2, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||||
|
fs.SwitchCount = 3 // > MaxSwitches(2)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Less(t, elapsed, 100*time.Millisecond, "不应等待")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503但context已取消_返回Canceled", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleSelectionExhausted(ctx)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverCanceled, action)
|
||||||
|
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(2, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||||
|
fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试
|
||||||
|
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||||
|
|
||||||
if platform == service.PlatformGemini {
|
if platform == service.PlatformGemini {
|
||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||||
switchCount := 0
|
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
|
||||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
|
||||||
|
|
||||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
@@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
switch action {
|
||||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
case FailoverContinue:
|
||||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
c.Request = c.Request.WithContext(ctx)
|
||||||
reqLog.Warn("gateway.single_account_retrying",
|
continue
|
||||||
zap.Int("retry_count", switchCount),
|
case FailoverCanceled:
|
||||||
zap.Int("max_retries", maxAccountSwitches),
|
return
|
||||||
)
|
default: // FailoverExhausted
|
||||||
failedAccountIDs = make(map[int64]struct{})
|
if fs.LastFailoverErr != nil {
|
||||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
} else {
|
||||||
c.Request = c.Request.WithContext(ctx)
|
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if lastFailoverErr != nil {
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
|
||||||
} else {
|
|
||||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if switchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||||
@@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
lastFailoverErr = failoverErr
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
switch action {
|
||||||
forceCacheBilling = true
|
case FailoverContinue:
|
||||||
}
|
|
||||||
|
|
||||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
|
||||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
|
||||||
sameAccountRetryCount[account.ID]++
|
|
||||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
|
||||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
|
||||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
case FailoverExhausted:
|
||||||
|
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||||
// 同账号重试用尽,执行临时封禁并切换账号
|
return
|
||||||
if failoverErr.RetryableOnSameAccount {
|
case FailoverCanceled:
|
||||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
if switchCount >= maxAccountSwitches {
|
|
||||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
|
||||||
reqLog.Warn("gateway.upstream_failover_switching",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
|
||||||
)
|
|
||||||
if account.Platform == service.PlatformAntigravity {
|
|
||||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("gateway.forward_failed",
|
reqLog.Error("gateway.forward_failed",
|
||||||
@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
ForceCacheBilling: forceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
@@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
|
||||||
switchCount := 0
|
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
|
||||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
|
||||||
retryWithFallback := false
|
retryWithFallback := false
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// 选择支持该模型的账号
|
// 选择支持该模型的账号
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
switch action {
|
||||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
case FailoverContinue:
|
||||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
c.Request = c.Request.WithContext(ctx)
|
||||||
reqLog.Warn("gateway.single_account_retrying",
|
continue
|
||||||
zap.Int("retry_count", switchCount),
|
case FailoverCanceled:
|
||||||
zap.Int("max_retries", maxAccountSwitches),
|
return
|
||||||
)
|
default: // FailoverExhausted
|
||||||
failedAccountIDs = make(map[int64]struct{})
|
if fs.LastFailoverErr != nil {
|
||||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
} else {
|
||||||
c.Request = c.Request.WithContext(ctx)
|
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if lastFailoverErr != nil {
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
|
||||||
} else {
|
|
||||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if switchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||||
@@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
lastFailoverErr = failoverErr
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
switch action {
|
||||||
forceCacheBilling = true
|
case FailoverContinue:
|
||||||
}
|
|
||||||
|
|
||||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
|
||||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
|
||||||
sameAccountRetryCount[account.ID]++
|
|
||||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
|
||||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
|
||||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
case FailoverExhausted:
|
||||||
|
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
|
||||||
// 同账号重试用尽,执行临时封禁并切换账号
|
return
|
||||||
if failoverErr.RetryableOnSameAccount {
|
case FailoverCanceled:
|
||||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
if switchCount >= maxAccountSwitches {
|
|
||||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
|
||||||
reqLog.Warn("gateway.upstream_failover_switching",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
|
||||||
)
|
|
||||||
if account.Platform == service.PlatformAntigravity {
|
|
||||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("gateway.forward_failed",
|
reqLog.Error("gateway.forward_failed",
|
||||||
@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Subscription: currentSubscription,
|
Subscription: currentSubscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
ForceCacheBilling: forceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
@@ -733,11 +650,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
).Error("gateway.record_usage_failed", zap.Error(err))
|
).Error("gateway.record_usage_failed", zap.Error(err))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
reqLog.Debug("gateway.request_completed",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Bool("fallback_used", fallbackUsed),
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !retryWithFallback {
|
if !retryWithFallback {
|
||||||
@@ -982,69 +894,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
|||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
|
|
||||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
|
|
||||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
|
||||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
|
||||||
maxSameAccountRetries = 2
|
|
||||||
// sameAccountRetryDelay 同账号重试间隔
|
|
||||||
sameAccountRetryDelay = 500 * time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
|
|
||||||
func sleepSameAccountRetryDelay(ctx context.Context) bool {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false
|
|
||||||
case <-time.After(sameAccountRetryDelay):
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
|
||||||
// 返回 false 表示 context 已取消。
|
|
||||||
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
|
||||||
delay := time.Duration(switchCount-1) * time.Second
|
|
||||||
if delay <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false
|
|
||||||
case <-time.After(delay):
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
|
|
||||||
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
|
|
||||||
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
|
|
||||||
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
|
|
||||||
// 返回 false 表示 context 已取消。
|
|
||||||
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
|
|
||||||
// 固定短延时:2s
|
|
||||||
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
|
|
||||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
|
||||||
const delay = 2 * time.Second
|
|
||||||
|
|
||||||
logger.L().With(
|
|
||||||
zap.String("component", "handler.gateway.failover"),
|
|
||||||
zap.Duration("delay", delay),
|
|
||||||
zap.Int("retry_count", retryCount),
|
|
||||||
).Info("gateway.single_account_backoff_waiting")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false
|
|
||||||
case <-time.After(delay):
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||||
statusCode := failoverErr.StatusCode
|
statusCode := failoverErr.StatusCode
|
||||||
responseBody := failoverErr.ResponseBody
|
responseBody := failoverErr.ResponseBody
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// sleepAntigravitySingleAccountBackoff 测试
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
start := time.Now()
|
|
||||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
require.True(t, ok, "should return true when context is not canceled")
|
|
||||||
// 固定延迟 2s
|
|
||||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
|
|
||||||
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel() // 立即取消
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
require.False(t, ok, "should return false when context is canceled")
|
|
||||||
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
|
|
||||||
// 验证不同 retryCount 都使用固定 2s 延迟
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
require.True(t, ok)
|
|
||||||
// 即使 retryCount=5,延迟仍然是固定的 2s
|
|
||||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
|
|
||||||
require.Less(t, elapsed, 5*time.Second)
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,340 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
|
||||||
|
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
|
||||||
|
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
|
||||||
|
|
||||||
|
type fakeSchedulerCache struct {
|
||||||
|
accounts []*service.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||||
|
return f.accounts, true, nil
|
||||||
|
}
|
||||||
|
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
|
||||||
|
func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil }
|
||||||
|
func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil }
|
||||||
|
func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil }
|
||||||
|
|
||||||
|
type fakeGroupRepo struct {
|
||||||
|
group *service.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil }
|
||||||
|
func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) {
|
||||||
|
return f.group, nil
|
||||||
|
}
|
||||||
|
func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
|
||||||
|
return f.group, nil
|
||||||
|
}
|
||||||
|
func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil }
|
||||||
|
func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil }
|
||||||
|
func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil }
|
||||||
|
func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil }
|
||||||
|
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||||
|
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
||||||
|
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil }
|
||||||
|
func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeConcurrencyCache struct{}
|
||||||
|
|
||||||
|
func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil }
|
||||||
|
func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil }
|
||||||
|
func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil }
|
||||||
|
func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil }
|
||||||
|
func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil }
|
||||||
|
func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||||
|
return map[int64]*service.AccountLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||||
|
return map[int64]*service.UserLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||||
|
|
||||||
|
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
schedulerCache := &fakeSchedulerCache{accounts: accounts}
|
||||||
|
schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
gwSvc := service.NewGatewayService(
|
||||||
|
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||||
|
&fakeGroupRepo{group: group},
|
||||||
|
nil, // usageLogRepo
|
||||||
|
nil, // userRepo
|
||||||
|
nil, // userSubRepo
|
||||||
|
nil, // userGroupRateRepo
|
||||||
|
nil, // cache (disable sticky)
|
||||||
|
nil, // cfg
|
||||||
|
schedulerSnapshot,
|
||||||
|
nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
|
||||||
|
nil, // billingService
|
||||||
|
nil, // rateLimitService
|
||||||
|
nil, // billingCacheService
|
||||||
|
nil, // identityService
|
||||||
|
nil, // httpUpstream
|
||||||
|
nil, // deferredService
|
||||||
|
nil, // claudeTokenProvider
|
||||||
|
nil, // sessionLimitCache
|
||||||
|
nil, // digestStore
|
||||||
|
)
|
||||||
|
|
||||||
|
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||||
|
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||||
|
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||||
|
|
||||||
|
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||||
|
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||||
|
|
||||||
|
h := &GatewayHandler{
|
||||||
|
gatewayService: gwSvc,
|
||||||
|
billingCacheService: billingCacheSvc,
|
||||||
|
concurrencyHelper: concurrencyHelper,
|
||||||
|
// 这些字段对本测试不敏感,保持较小即可
|
||||||
|
maxAccountSwitches: 1,
|
||||||
|
maxAccountSwitchesGemini: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
billingCacheSvc.Stop()
|
||||||
|
}
|
||||||
|
return h, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
groupID := int64(2001)
|
||||||
|
accountID := int64(1001)
|
||||||
|
|
||||||
|
group := &service.Group{
|
||||||
|
ID: groupID,
|
||||||
|
Hydrated: true,
|
||||||
|
Platform: service.PlatformAnthropic, // /v1/messages(Claude兼容)入口
|
||||||
|
Status: service.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &service.Account{
|
||||||
|
ID: accountID,
|
||||||
|
Name: "ag-1",
|
||||||
|
Platform: service.PlatformAntigravity,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "tok_xxx",
|
||||||
|
"intercept_warmup_requests": true,
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中
|
||||||
|
},
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 1,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||||
|
}
|
||||||
|
|
||||||
|
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
body := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"max_tokens": 256,
|
||||||
|
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||||
|
}`)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 3001,
|
||||||
|
UserID: 4001,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
User: &service.User{
|
||||||
|
ID: 4001,
|
||||||
|
Concurrency: 10,
|
||||||
|
Balance: 100,
|
||||||
|
},
|
||||||
|
Group: group,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||||
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||||
|
|
||||||
|
h.Messages(c)
|
||||||
|
|
||||||
|
require.Equal(t, 200, rec.Code)
|
||||||
|
|
||||||
|
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
|
||||||
|
selected, ok := c.Get(opsAccountIDKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, accountID, selected)
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||||
|
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||||
|
|
||||||
|
content, ok := resp["content"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, content, 1)
|
||||||
|
first, ok := content[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "New Conversation", first["text"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
groupID := int64(2002)
|
||||||
|
accountID := int64(1002)
|
||||||
|
|
||||||
|
group := &service.Group{
|
||||||
|
ID: groupID,
|
||||||
|
Hydrated: true,
|
||||||
|
Platform: service.PlatformAntigravity,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &service.Account{
|
||||||
|
ID: accountID,
|
||||||
|
Name: "ag-2",
|
||||||
|
Platform: service.PlatformAntigravity,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "tok_xxx",
|
||||||
|
"intercept_warmup_requests": true,
|
||||||
|
},
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 1,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||||
|
}
|
||||||
|
|
||||||
|
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
body := []byte(`{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"max_tokens": 256,
|
||||||
|
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||||
|
}`)
|
||||||
|
req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
|
||||||
|
// - 写入 request.Context(Service读取)
|
||||||
|
// - 写入 gin.Context(Handler快速读取)
|
||||||
|
ctx := context.WithValue(req.Context(), ctxkey.Group, group)
|
||||||
|
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
c.Request = req
|
||||||
|
c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)
|
||||||
|
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 3002,
|
||||||
|
UserID: 4002,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
User: &service.User{
|
||||||
|
ID: 4002,
|
||||||
|
Concurrency: 10,
|
||||||
|
Balance: 100,
|
||||||
|
},
|
||||||
|
Group: group,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||||
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||||
|
|
||||||
|
h.Messages(c)
|
||||||
|
|
||||||
|
require.Equal(t, 200, rec.Code)
|
||||||
|
|
||||||
|
selected, ok := c.Get(opsAccountIDKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, accountID, selected)
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||||
|
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||||
|
}
|
||||||
@@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||||
cleanedForUnknownBinding := false
|
cleanedForUnknownBinding := false
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||||
switchCount := 0
|
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
|
||||||
|
|
||||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
@@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
switch action {
|
||||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
case FailoverContinue:
|
||||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
c.Request = c.Request.WithContext(ctx)
|
||||||
reqLog.Warn("gemini.single_account_retrying",
|
continue
|
||||||
zap.Int("retry_count", switchCount),
|
case FailoverCanceled:
|
||||||
zap.Int("max_retries", maxAccountSwitches),
|
return
|
||||||
)
|
default: // FailoverExhausted
|
||||||
failedAccountIDs = make(map[int64]struct{})
|
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
return
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
@@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
// 5) forward (根据平台分流)
|
// 5) forward (根据平台分流)
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if switchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||||
@@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
switch failoverAction {
|
||||||
forceCacheBilling = true
|
case FailoverContinue:
|
||||||
}
|
continue
|
||||||
if switchCount >= maxAccountSwitches {
|
case FailoverExhausted:
|
||||||
lastFailoverErr = failoverErr
|
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
return
|
||||||
|
case FailoverCanceled:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverErr = failoverErr
|
|
||||||
switchCount++
|
|
||||||
reqLog.Warn("gemini.upstream_failover_switching",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
|
||||||
)
|
|
||||||
if account.Platform == service.PlatformAntigravity {
|
|
||||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
// ForwardNative already wrote the response
|
// ForwardNative already wrote the response
|
||||||
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
@@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||||
ForceCacheBilling: forceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
@@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
reqLog.Debug("gemini.request_completed",
|
reqLog.Debug("gemini.request_completed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Int("switch_count", switchCount),
|
zap.Int("switch_count", fs.SwitchCount),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -400,7 +400,9 @@ func TestShouldFallbackToNextURL_无错误且200(t *testing.T) {
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
func TestClient_ExchangeCode_成功(t *testing.T) {
|
func TestClient_ExchangeCode_成功(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// 验证请求方法
|
// 验证请求方法
|
||||||
@@ -493,7 +495,9 @@ func TestClient_ExchangeCode_成功(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
|
func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = ""
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
client := NewClient("")
|
client := NewClient("")
|
||||||
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
|
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
|
||||||
@@ -506,7 +510,9 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
|
func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
@@ -531,7 +537,9 @@ func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
func TestClient_RefreshToken_MockServer(t *testing.T) {
|
func TestClient_RefreshToken_MockServer(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@@ -590,7 +598,9 @@ func TestClient_RefreshToken_MockServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
|
func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = ""
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
client := NewClient("")
|
client := NewClient("")
|
||||||
_, err := client.RefreshToken(context.Background(), "refresh-tok")
|
_, err := client.RefreshToken(context.Background(), "refresh-tok")
|
||||||
@@ -784,7 +794,9 @@ func newTestClientWithRedirect(redirects map[string]string) *Client {
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
|
func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@@ -853,7 +865,9 @@ func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
|
func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
@@ -878,7 +892,9 @@ func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
|
func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -901,7 +917,9 @@ func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
|
func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
time.Sleep(5 * time.Second) // 模拟慢响应
|
time.Sleep(5 * time.Second) // 模拟慢响应
|
||||||
@@ -927,7 +945,9 @@ func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
|
func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@@ -976,7 +996,9 @@ func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
|
func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
@@ -998,7 +1020,9 @@ func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
|
func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -1021,7 +1045,9 @@ func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) {
|
func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = "test-secret"
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
|
|||||||
@@ -23,11 +23,9 @@ const (
|
|||||||
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||||
|
|
||||||
// Antigravity OAuth 客户端凭证
|
// Antigravity OAuth 客户端凭证
|
||||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
ClientSecret = ""
|
|
||||||
|
|
||||||
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
||||||
// 出于安全原因,该值不得硬编码入库。
|
|
||||||
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
||||||
|
|
||||||
// 固定的 redirect_uri(用户需手动复制 code)
|
// 固定的 redirect_uri(用户需手动复制 code)
|
||||||
@@ -51,14 +49,21 @@ const (
|
|||||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.84.2
|
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.18.4
|
||||||
var defaultUserAgentVersion = "1.84.2"
|
var defaultUserAgentVersion = "1.18.4"
|
||||||
|
|
||||||
|
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||||
|
var defaultClientSecret = "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"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// 从环境变量读取版本号,未设置则使用默认值
|
// 从环境变量读取版本号,未设置则使用默认值
|
||||||
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
|
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
|
||||||
defaultUserAgentVersion = version
|
defaultUserAgentVersion = version
|
||||||
}
|
}
|
||||||
|
// 从环境变量读取 client_secret,未设置则使用默认值
|
||||||
|
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
|
||||||
|
defaultClientSecret = secret
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserAgent 返回当前配置的 User-Agent
|
// GetUserAgent 返回当前配置的 User-Agent
|
||||||
@@ -67,14 +72,9 @@ func GetUserAgent() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getClientSecret() (string, error) {
|
func getClientSecret() (string, error) {
|
||||||
if v := strings.TrimSpace(ClientSecret); v != "" {
|
if v := strings.TrimSpace(defaultClientSecret); v != "" {
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
|
|
||||||
if vv := strings.TrimSpace(v); vv != "" {
|
|
||||||
return vv, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
|
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -17,8 +18,14 @@ import (
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
func TestGetClientSecret_环境变量设置(t *testing.T) {
|
func TestGetClientSecret_环境变量设置(t *testing.T) {
|
||||||
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = ""
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
|
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
|
||||||
|
|
||||||
|
// 需要重新触发 init 逻辑:手动从环境变量读取
|
||||||
|
defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv)
|
||||||
|
|
||||||
secret, err := getClientSecret()
|
secret, err := getClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||||
@@ -29,11 +36,13 @@ func TestGetClientSecret_环境变量设置(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetClientSecret_环境变量为空(t *testing.T) {
|
func TestGetClientSecret_环境变量为空(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = ""
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
_, err := getClientSecret()
|
_, err := getClientSecret()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("环境变量为空时应返回错误")
|
t.Fatal("defaultClientSecret 为空时应返回错误")
|
||||||
}
|
}
|
||||||
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
|
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
|
||||||
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
|
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
|
||||||
@@ -41,30 +50,31 @@ func TestGetClientSecret_环境变量为空(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetClientSecret_环境变量未设置(t *testing.T) {
|
func TestGetClientSecret_环境变量未设置(t *testing.T) {
|
||||||
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
|
old := defaultClientSecret
|
||||||
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
|
defaultClientSecret = ""
|
||||||
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
// 明确设置再取消,确保环境变量不存在
|
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
|
||||||
|
|
||||||
_, err := getClientSecret()
|
_, err := getClientSecret()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("环境变量未设置时应返回错误")
|
t.Fatal("defaultClientSecret 为空时应返回错误")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetClientSecret_环境变量含空格(t *testing.T) {
|
func TestGetClientSecret_环境变量含空格(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = " "
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
_, err := getClientSecret()
|
_, err := getClientSecret()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("环境变量仅含空格时应返回错误")
|
t.Fatal("defaultClientSecret 仅含空格时应返回错误")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
|
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
|
||||||
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
|
old := defaultClientSecret
|
||||||
|
defaultClientSecret = " valid-secret "
|
||||||
|
t.Cleanup(func() { defaultClientSecret = old })
|
||||||
|
|
||||||
secret, err := getClientSecret()
|
secret, err := getClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -670,13 +680,17 @@ func TestConstants_值正确(t *testing.T) {
|
|||||||
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
|
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
|
||||||
t.Errorf("ClientID 不匹配: got %s", ClientID)
|
t.Errorf("ClientID 不匹配: got %s", ClientID)
|
||||||
}
|
}
|
||||||
if ClientSecret != "" {
|
secret, err := getClientSecret()
|
||||||
t.Error("ClientSecret 应为空字符串")
|
if err != nil {
|
||||||
|
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
||||||
|
}
|
||||||
|
if secret != "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" {
|
||||||
|
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
||||||
}
|
}
|
||||||
if RedirectURI != "http://localhost:8085/callback" {
|
if RedirectURI != "http://localhost:8085/callback" {
|
||||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||||
}
|
}
|
||||||
if GetUserAgent() != "antigravity/1.84.2 windows/amd64" {
|
if GetUserAgent() != "antigravity/1.18.4 windows/amd64" {
|
||||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||||
}
|
}
|
||||||
if SessionTTL != 30*time.Minute {
|
if SessionTTL != 30*time.Minute {
|
||||||
|
|||||||
@@ -206,6 +206,7 @@ type modelInfo struct {
|
|||||||
var modelInfoMap = map[string]modelInfo{
|
var modelInfoMap = map[string]modelInfo{
|
||||||
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
||||||
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
||||||
|
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
|
||||||
"claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"},
|
"claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"},
|
||||||
"claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"},
|
"claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,10 +38,8 @@ const (
|
|||||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||||
// restrict which scopes are allowed for this client.
|
// restrict which scopes are allowed for this client.
|
||||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||||
// GeminiCLIOAuthClientSecret is intentionally not embedded in this repository.
|
GeminiCLIOAuthClientSecret = "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"
|
||||||
// If you rely on the built-in Gemini CLI OAuth client, you MUST provide its client_secret via config/env.
|
|
||||||
GeminiCLIOAuthClientSecret = ""
|
|
||||||
|
|
||||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||||
|
|||||||
@@ -408,11 +408,10 @@ func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
|
func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) {
|
||||||
// 不设置环境变量,也不提供 client 凭据,EffectiveOAuthConfig 应该报错
|
|
||||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||||
|
|
||||||
_, err := BuildAuthorizationURL(
|
authURL, err := BuildAuthorizationURL(
|
||||||
OAuthConfig{},
|
OAuthConfig{},
|
||||||
"test-state",
|
"test-state",
|
||||||
"test-challenge",
|
"test-challenge",
|
||||||
@@ -420,8 +419,11 @@ func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
|
|||||||
"",
|
"",
|
||||||
"code_assist",
|
"code_assist",
|
||||||
)
|
)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
t.Error("当 EffectiveOAuthConfig 失败时,BuildAuthorizationURL 应该返回错误")
|
t.Fatalf("BuildAuthorizationURL() 不应报错: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(authURL, "client_id="+GeminiCLIOAuthClientID) {
|
||||||
|
t.Errorf("应使用内置 Gemini CLI client_id,实际 URL: %s", authURL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -685,15 +687,17 @@ func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
|
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
|
||||||
// 不设置环境变量且不提供凭据,应该报错
|
|
||||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||||
|
|
||||||
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||||
if err == nil {
|
if err != nil {
|
||||||
t.Error("没有内置 secret 且未提供凭据时应该报错")
|
t.Fatalf("不设置环境变量时应回退到内置 secret,实际报错: %v", err)
|
||||||
}
|
}
|
||||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
if strings.TrimSpace(cfg.ClientSecret) == "" {
|
||||||
t.Errorf("错误消息应提及环境变量 %s,实际: %v", GeminiCLIOAuthClientSecretEnv, err)
|
t.Error("ClientSecret 不应为空")
|
||||||
|
}
|
||||||
|
if cfg.ClientID != GeminiCLIOAuthClientID {
|
||||||
|
t.Errorf("ClientID 应回退为内置客户端 ID,实际: %q", cfg.ClientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -219,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
accounts.GET("", h.Admin.Account.List)
|
accounts.GET("", h.Admin.Account.List)
|
||||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||||
accounts.POST("", h.Admin.Account.Create)
|
accounts.POST("", h.Admin.Account.Create)
|
||||||
|
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
|
||||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||||
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
||||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||||
|
|||||||
@@ -372,6 +372,13 @@ func (a *Account) GetModelMapping() map[string]string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(result) > 0 {
|
if len(result) > 0 {
|
||||||
|
if a.Platform == domain.PlatformAntigravity {
|
||||||
|
ensureAntigravityDefaultPassthroughs(result, []string{
|
||||||
|
"gemini-3-flash",
|
||||||
|
"gemini-3.1-pro-high",
|
||||||
|
"gemini-3.1-pro-low",
|
||||||
|
})
|
||||||
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -382,6 +389,27 @@ func (a *Account) GetModelMapping() map[string]string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
|
||||||
|
if mapping == nil || model == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := mapping[model]; exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for pattern := range mapping {
|
||||||
|
if matchWildcard(pattern, model) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mapping[model] = model
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []string) {
|
||||||
|
for _, model := range models {
|
||||||
|
ensureAntigravityDefaultPassthrough(mapping, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
||||||
// 如果未配置 mapping,返回 true(允许所有模型)
|
// 如果未配置 mapping,返回 true(允许所有模型)
|
||||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||||
|
|||||||
66
backend/internal/service/account_intercept_warmup_test.go
Normal file
66
backend/internal/service/account_intercept_warmup_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccount_IsInterceptWarmupEnabled(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
credentials map[string]any
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil credentials",
|
||||||
|
credentials: nil,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty map",
|
||||||
|
credentials: map[string]any{},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field not present",
|
||||||
|
credentials: map[string]any{"access_token": "tok"},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field is true",
|
||||||
|
credentials: map[string]any{"intercept_warmup_requests": true},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field is false",
|
||||||
|
credentials: map[string]any{"intercept_warmup_requests": false},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field is string true",
|
||||||
|
credentials: map[string]any{"intercept_warmup_requests": "true"},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field is int 1",
|
||||||
|
credentials: map[string]any{"intercept_warmup_requests": 1},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "field is nil",
|
||||||
|
credentials: map[string]any{"intercept_warmup_requests": nil},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Account{Credentials: tt.credentials}
|
||||||
|
result := a.IsInterceptWarmupEnabled()
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -217,12 +218,20 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
}
|
}
|
||||||
|
|
||||||
if account.Platform == PlatformGemini {
|
if account.Platform == PlatformGemini {
|
||||||
return s.getGeminiUsage(ctx, account)
|
usage, err := s.getGeminiUsage(ctx, account)
|
||||||
|
if err == nil {
|
||||||
|
s.tryClearRecoverableAccountError(ctx, account)
|
||||||
|
}
|
||||||
|
return usage, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
return s.getAntigravityUsage(ctx, account)
|
usage, err := s.getAntigravityUsage(ctx, account)
|
||||||
|
if err == nil {
|
||||||
|
s.tryClearRecoverableAccountError(ctx, account)
|
||||||
|
}
|
||||||
|
return usage, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||||
@@ -256,6 +265,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
|||||||
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
||||||
s.addWindowStats(ctx, account, usage)
|
s.addWindowStats(ctx, account, usage)
|
||||||
|
|
||||||
|
s.tryClearRecoverableAccountError(ctx, account)
|
||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -486,6 +496,32 @@ func parseTime(s string) (time.Time, error) {
|
|||||||
return time.Time{}, fmt.Errorf("unable to parse time: %s", s)
|
return time.Time{}, fmt.Errorf("unable to parse time: %s", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountUsageService) tryClearRecoverableAccountError(ctx context.Context, account *Account) {
|
||||||
|
if account == nil || account.Status != StatusError {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(strings.TrimSpace(account.ErrorMessage))
|
||||||
|
if msg == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(msg, "token refresh failed") &&
|
||||||
|
!strings.Contains(msg, "invalid_client") &&
|
||||||
|
!strings.Contains(msg, "missing_project_id") &&
|
||||||
|
!strings.Contains(msg, "unauthenticated") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
|
||||||
|
log.Printf("[usage] failed to clear recoverable account error for account %d: %v", account.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Status = StatusActive
|
||||||
|
account.ErrorMessage = ""
|
||||||
|
}
|
||||||
|
|
||||||
// buildUsageInfo 构建UsageInfo
|
// buildUsageInfo 构建UsageInfo
|
||||||
func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo {
|
func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo {
|
||||||
info := &UsageInfo{
|
info := &UsageInfo{
|
||||||
|
|||||||
@@ -267,3 +267,50 @@ func TestAccountGetMappedModel(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-3-pro-high": "gemini-3.1-pro-high",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapping := account.GetModelMapping()
|
||||||
|
if mapping["gemini-3-flash"] != "gemini-3-flash" {
|
||||||
|
t.Fatalf("expected gemini-3-flash passthrough to be auto-filled, got: %q", mapping["gemini-3-flash"])
|
||||||
|
}
|
||||||
|
if mapping["gemini-3.1-pro-high"] != "gemini-3.1-pro-high" {
|
||||||
|
t.Fatalf("expected gemini-3.1-pro-high passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-high"])
|
||||||
|
}
|
||||||
|
if mapping["gemini-3.1-pro-low"] != "gemini-3.1-pro-low" {
|
||||||
|
t.Fatalf("expected gemini-3.1-pro-low passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-low"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-3*": "gemini-3.1-pro-high",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapping := account.GetModelMapping()
|
||||||
|
if _, exists := mapping["gemini-3-flash"]; exists {
|
||||||
|
t.Fatalf("did not expect explicit gemini-3-flash passthrough when wildcard already exists")
|
||||||
|
}
|
||||||
|
if _, exists := mapping["gemini-3.1-pro-high"]; exists {
|
||||||
|
t.Fatalf("did not expect explicit gemini-3.1-pro-high passthrough when wildcard already exists")
|
||||||
|
}
|
||||||
|
if _, exists := mapping["gemini-3.1-pro-low"]; exists {
|
||||||
|
t.Fatalf("did not expect explicit gemini-3.1-pro-low passthrough when wildcard already exists")
|
||||||
|
}
|
||||||
|
if mapped := account.GetMappedModel("gemini-3-flash"); mapped != "gemini-3.1-pro-high" {
|
||||||
|
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ type AdminService interface {
|
|||||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||||
|
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
||||||
|
|
||||||
// Proxy management
|
// Proxy management
|
||||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
|
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
|
||||||
@@ -2114,6 +2115,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
|
||||||
|
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||||
|
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
|
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
|
||||||
if s.proxyLatencyCache == nil || len(proxies) == 0 {
|
if s.proxyLatencyCache == nil || len(proxies) == 0 {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -87,7 +87,6 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
|
||||||
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
|
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
|
||||||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||||
)
|
)
|
||||||
@@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||||||
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||||||
|
billingModel := mappedModel
|
||||||
|
|
||||||
// 获取 access_token
|
// 获取 access_token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
@@ -1370,6 +1370,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
ForceCacheBilling: switchErr.IsStickySession,
|
ForceCacheBilling: switchErr.IsStickySession,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||||
|
if c.Request.Context().Err() != nil {
|
||||||
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
|
||||||
|
}
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||||
}
|
}
|
||||||
resp := result.resp
|
resp := result.resp
|
||||||
@@ -1618,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel, // 使用原始模型用于计费和日志
|
Model: billingModel, // 使用映射模型用于计费和日志
|
||||||
Stream: claudeReq.Stream,
|
Stream: claudeReq.Stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
@@ -1972,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
if mappedModel == "" {
|
if mappedModel == "" {
|
||||||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
||||||
}
|
}
|
||||||
|
billingModel := mappedModel
|
||||||
|
|
||||||
// 获取 access_token
|
// 获取 access_token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
@@ -2042,6 +2047,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
ForceCacheBilling: switchErr.IsStickySession,
|
ForceCacheBilling: switchErr.IsStickySession,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||||
|
if c.Request.Context().Err() != nil {
|
||||||
|
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response")
|
||||||
|
}
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||||
}
|
}
|
||||||
resp := result.resp
|
resp := result.resp
|
||||||
@@ -2197,7 +2206,7 @@ handleSuccess:
|
|||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel,
|
Model: billingModel,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
@@ -2642,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
|||||||
defaultDur := s.getDefaultRateLimitDuration()
|
defaultDur := s.getDefaultRateLimitDuration()
|
||||||
|
|
||||||
// 尝试解析模型 key 并设置模型级限流
|
// 尝试解析模型 key 并设置模型级限流
|
||||||
modelKey := resolveAntigravityModelKey(requestedModel)
|
//
|
||||||
|
// 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6),
|
||||||
|
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
|
||||||
|
// 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。
|
||||||
|
modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||||
|
if strings.TrimSpace(modelKey) == "" {
|
||||||
|
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
|
||||||
|
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
|
||||||
|
modelKey = resolveAntigravityModelKey(requestedModel)
|
||||||
|
}
|
||||||
if modelKey != "" {
|
if modelKey != "" {
|
||||||
ra := s.resolveResetTime(resetAt, defaultDur)
|
ra := s.resolveResetTime(resetAt, defaultDur)
|
||||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
||||||
@@ -3881,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
return nil, fmt.Errorf("missing model")
|
return nil, fmt.Errorf("missing model")
|
||||||
}
|
}
|
||||||
originalModel := claudeReq.Model
|
originalModel := claudeReq.Model
|
||||||
billingModel := originalModel
|
|
||||||
|
|
||||||
// 构建上游请求 URL
|
// 构建上游请求 URL
|
||||||
upstreamURL := baseURL + "/v1/messages"
|
upstreamURL := baseURL + "/v1/messages"
|
||||||
@@ -3934,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
_, _ = c.Writer.Write(respBody)
|
_, _ = c.Writer.Write(respBody)
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
Model: billingModel,
|
Model: originalModel,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3975,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
Model: billingModel,
|
Model: originalModel,
|
||||||
Stream: claudeReq.Stream,
|
Stream: claudeReq.Stream,
|
||||||
Duration: duration,
|
Duration: duration,
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
|||||||
@@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
|||||||
return s.resp, s.err
|
return s.resp, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type antigravitySettingRepoStub struct{}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||||
|
panic("unexpected Get call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
return "", ErrSettingNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||||
|
panic("unexpected Set call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||||
|
panic("unexpected GetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||||
|
panic("unexpected SetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||||
|
panic("unexpected GetAll call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
writer := httptest.NewRecorder()
|
writer := httptest.NewRecorder()
|
||||||
@@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
svc := &AntigravityGatewayService{
|
svc := &AntigravityGatewayService{
|
||||||
tokenProvider: &AntigravityTokenProvider{},
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||||
}
|
}
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
@@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
|||||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
|
||||||
|
// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型
|
||||||
|
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": []map[string]any{
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
},
|
||||||
|
"max_tokens": 16,
|
||||||
|
"stream": true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"req-bill-1"}},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||||
|
}
|
||||||
|
|
||||||
|
const mappedModel = "gemini-3-pro-high"
|
||||||
|
account := &Account{
|
||||||
|
ID: 5,
|
||||||
|
Name: "acc-forward-billing",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"claude-sonnet-4-5": mappedModel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, mappedModel, result.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
|
||||||
|
// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型
|
||||||
|
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"req-bill-2"}},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||||
|
}
|
||||||
|
|
||||||
|
const mappedModel = "gemini-3-pro-high"
|
||||||
|
account := &Account{
|
||||||
|
ID: 6,
|
||||||
|
Name: "acc-gemini-billing",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-2.5-flash": mappedModel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, mappedModel, result.Model)
|
||||||
|
}
|
||||||
|
|
||||||
// TestStreamUpstreamResponse_UsageAndFirstToken
|
// TestStreamUpstreamResponse_UsageAndFirstToken
|
||||||
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
||||||
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||||
|
|||||||
@@ -76,6 +76,12 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// 3. 默认映射中的透传(映射到自己)
|
// 3. 默认映射中的透传(映射到自己)
|
||||||
|
{
|
||||||
|
name: "默认映射透传 - claude-sonnet-4-6",
|
||||||
|
requestedModel: "claude-sonnet-4-6",
|
||||||
|
accountMapping: nil,
|
||||||
|
expected: "claude-sonnet-4-6",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "默认映射透传 - claude-sonnet-4-5",
|
name: "默认映射透传 - claude-sonnet-4-5",
|
||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
|
|||||||
@@ -197,6 +197,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
|||||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景
|
||||||
|
// 验证:requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking)
|
||||||
|
func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity}
|
||||||
|
|
||||||
|
body := buildGeminiRateLimitBody("5s")
|
||||||
|
|
||||||
|
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false)
|
||||||
|
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||||
|
require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey)
|
||||||
|
}
|
||||||
|
|
||||||
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
||||||
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
||||||
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
||||||
|
|||||||
@@ -133,6 +133,18 @@ func (s *BillingService) initFallbackPricing() {
|
|||||||
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
|
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
|
||||||
SupportsCacheBreakdown: false,
|
SupportsCacheBreakdown: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Claude 4.6 Opus (与4.5同价)
|
||||||
|
s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"]
|
||||||
|
|
||||||
|
// Gemini 3.1 Pro
|
||||||
|
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
|
||||||
|
InputPricePerToken: 2e-6, // $2 per MTok
|
||||||
|
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||||
|
CacheCreationPricePerToken: 2e-6, // $2 per MTok
|
||||||
|
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
|
||||||
|
SupportsCacheBreakdown: false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getFallbackPricing 根据模型系列获取回退价格
|
// getFallbackPricing 根据模型系列获取回退价格
|
||||||
@@ -141,6 +153,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
|||||||
|
|
||||||
// 按模型系列匹配
|
// 按模型系列匹配
|
||||||
if strings.Contains(modelLower, "opus") {
|
if strings.Contains(modelLower, "opus") {
|
||||||
|
if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") {
|
||||||
|
return s.fallbackPrices["claude-opus-4.6"]
|
||||||
|
}
|
||||||
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
|
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
|
||||||
return s.fallbackPrices["claude-opus-4.5"]
|
return s.fallbackPrices["claude-opus-4.5"]
|
||||||
}
|
}
|
||||||
@@ -158,6 +173,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
|||||||
}
|
}
|
||||||
return s.fallbackPrices["claude-3-haiku"]
|
return s.fallbackPrices["claude-3-haiku"]
|
||||||
}
|
}
|
||||||
|
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
|
||||||
|
return s.fallbackPrices["gemini-3.1-pro"]
|
||||||
|
}
|
||||||
|
|
||||||
// 默认使用Sonnet价格
|
// 默认使用Sonnet价格
|
||||||
return s.fallbackPrices["claude-sonnet-4"]
|
return s.fallbackPrices["claude-sonnet-4"]
|
||||||
|
|||||||
@@ -895,6 +895,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t
|
|||||||
require.Equal(t, int64(2), acc.ID)
|
require.Equal(t, int64(2), acc.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Priority: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Priority: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号")
|
||||||
|
|
||||||
|
acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, acc)
|
||||||
|
require.Contains(t, err.Error(), "supporting model")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
|
func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
groupID := int64(50)
|
groupID := int64(50)
|
||||||
@@ -1070,6 +1119,36 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
|||||||
model: "claude-3-5-sonnet-20241022",
|
model: "claude-3-5-sonnet-20241022",
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Gemini平台-无映射配置-支持所有模型",
|
||||||
|
account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey},
|
||||||
|
model: "gemini-2.5-flash",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||||
|
account: &Account{
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: "gemini-2.5-flash",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gemini平台-有映射配置-支持配置的模型",
|
||||||
|
account: &Account{
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: "gemini-2.5-pro",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -2825,10 +2825,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
|||||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||||
}
|
}
|
||||||
// Gemini API Key 账户直接透传,由上游判断模型是否支持
|
|
||||||
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// 其他平台使用账户的模型支持检查
|
// 其他平台使用账户的模型支持检查
|
||||||
return account.IsModelSupported(requestedModel)
|
return account.IsModelSupported(requestedModel)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,12 +107,12 @@ func TestIsModelRateLimited(t *testing.T) {
|
|||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3.1-pro-high",
|
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3-pro-high",
|
||||||
account: &Account{
|
account: &Account{
|
||||||
Platform: PlatformAntigravity,
|
Platform: PlatformAntigravity,
|
||||||
Extra: map[string]any{
|
Extra: map[string]any{
|
||||||
modelRateLimitsKey: map[string]any{
|
modelRateLimitsKey: map[string]any{
|
||||||
"gemini-3.1-pro-high": map[string]any{
|
"gemini-3-pro-high": map[string]any{
|
||||||
"rate_limit_reset_at": future,
|
"rate_limit_reset_at": future,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
42
backend/migrations/058_add_sonnet46_to_model_mapping.sql
Normal file
42
backend/migrations/058_add_sonnet46_to_model_mapping.sql
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
-- Add claude-sonnet-4-6 to model_mapping for all Antigravity accounts
|
||||||
|
--
|
||||||
|
-- Background:
|
||||||
|
-- Antigravity now supports claude-sonnet-4-6
|
||||||
|
--
|
||||||
|
-- Strategy:
|
||||||
|
-- Directly overwrite the entire model_mapping with updated mappings
|
||||||
|
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
|
||||||
|
|
||||||
|
UPDATE accounts
|
||||||
|
SET credentials = jsonb_set(
|
||||||
|
credentials,
|
||||||
|
'{model_mapping}',
|
||||||
|
'{
|
||||||
|
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-6": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||||
|
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
|
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||||
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
|
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||||
|
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||||
|
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||||
|
"gemini-3-flash": "gemini-3-flash",
|
||||||
|
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||||
|
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||||
|
"gemini-3-pro-image": "gemini-3-pro-image",
|
||||||
|
"gemini-3-flash-preview": "gemini-3-flash",
|
||||||
|
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||||
|
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||||
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
|
"tab_flash_lite_preview": "tab_flash_lite_preview"
|
||||||
|
}'::jsonb
|
||||||
|
)
|
||||||
|
WHERE platform = 'antigravity'
|
||||||
|
AND deleted_at IS NULL
|
||||||
|
AND credentials->'model_mapping' IS NOT NULL;
|
||||||
45
backend/migrations/059_add_gemini31_pro_to_model_mapping.sql
Normal file
45
backend/migrations/059_add_gemini31_pro_to_model_mapping.sql
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
-- Add gemini-3.1-pro-high, gemini-3.1-pro-low, gemini-3.1-pro-preview to model_mapping
|
||||||
|
--
|
||||||
|
-- Background:
|
||||||
|
-- Antigravity now supports gemini-3.1-pro-high and gemini-3.1-pro-low
|
||||||
|
--
|
||||||
|
-- Strategy:
|
||||||
|
-- Directly overwrite the entire model_mapping with updated mappings
|
||||||
|
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
|
||||||
|
|
||||||
|
UPDATE accounts
|
||||||
|
SET credentials = jsonb_set(
|
||||||
|
credentials,
|
||||||
|
'{model_mapping}',
|
||||||
|
'{
|
||||||
|
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-6": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
|
||||||
|
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||||
|
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
|
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||||
|
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||||
|
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||||
|
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||||
|
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||||
|
"gemini-3-flash": "gemini-3-flash",
|
||||||
|
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||||
|
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||||
|
"gemini-3-pro-image": "gemini-3-pro-image",
|
||||||
|
"gemini-3-flash-preview": "gemini-3-flash",
|
||||||
|
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||||
|
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||||
|
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
|
||||||
|
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
|
||||||
|
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
|
||||||
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
|
"tab_flash_lite_preview": "tab_flash_lite_preview"
|
||||||
|
}'::jsonb
|
||||||
|
)
|
||||||
|
WHERE platform = 'antigravity'
|
||||||
|
AND deleted_at IS NULL
|
||||||
|
AND credentials->'model_mapping' IS NOT NULL;
|
||||||
@@ -15,7 +15,9 @@ import type {
|
|||||||
AccountUsageStatsResponse,
|
AccountUsageStatsResponse,
|
||||||
TempUnschedulableStatus,
|
TempUnschedulableStatus,
|
||||||
AdminDataPayload,
|
AdminDataPayload,
|
||||||
AdminDataImportResult
|
AdminDataImportResult,
|
||||||
|
CheckMixedChannelRequest,
|
||||||
|
CheckMixedChannelResponse
|
||||||
} from '@/types'
|
} from '@/types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -133,6 +135,16 @@ export async function update(id: number, updates: UpdateAccountRequest): Promise
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check mixed-channel risk for account-group binding.
|
||||||
|
*/
|
||||||
|
export async function checkMixedChannelRisk(
|
||||||
|
payload: CheckMixedChannelRequest
|
||||||
|
): Promise<CheckMixedChannelResponse> {
|
||||||
|
const { data } = await apiClient.post<CheckMixedChannelResponse>('/admin/accounts/check-mixed-channel', payload)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Delete account
|
* Delete account
|
||||||
* @param id - Account ID
|
* @param id - Account ID
|
||||||
@@ -535,6 +547,7 @@ export const accountsAPI = {
|
|||||||
getById,
|
getById,
|
||||||
create,
|
create,
|
||||||
update,
|
update,
|
||||||
|
checkMixedChannelRisk,
|
||||||
delete: deleteAccount,
|
delete: deleteAccount,
|
||||||
toggleStatus,
|
toggleStatus,
|
||||||
testAccount,
|
testAccount,
|
||||||
|
|||||||
@@ -77,13 +77,23 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
|
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
|
||||||
<template v-if="activeModelRateLimits.length > 0">
|
<div
|
||||||
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
|
v-if="activeModelRateLimits.length > 0"
|
||||||
|
:class="[
|
||||||
|
activeModelRateLimits.length <= 4
|
||||||
|
? 'flex flex-col gap-1'
|
||||||
|
: activeModelRateLimits.length <= 8
|
||||||
|
? 'columns-2 gap-x-2'
|
||||||
|
: 'columns-3 gap-x-2'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative mb-1 break-inside-avoid">
|
||||||
<span
|
<span
|
||||||
class="inline-flex items-center gap-1 rounded bg-purple-100 px-1.5 py-0.5 text-xs font-medium text-purple-700 dark:bg-purple-900/30 dark:text-purple-400"
|
class="inline-flex items-center gap-1 rounded bg-purple-100 px-1.5 py-0.5 text-xs font-medium text-purple-700 dark:bg-purple-900/30 dark:text-purple-400"
|
||||||
>
|
>
|
||||||
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
|
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
|
||||||
{{ formatScopeName(item.model) }}
|
{{ formatScopeName(item.model) }}
|
||||||
|
<span class="text-[10px] opacity-70">{{ formatModelResetTime(item.reset_at) }}</span>
|
||||||
</span>
|
</span>
|
||||||
<!-- Tooltip -->
|
<!-- Tooltip -->
|
||||||
<div
|
<div
|
||||||
@@ -95,7 +105,7 @@
|
|||||||
></div>
|
></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</div>
|
||||||
|
|
||||||
<!-- Overload Indicator (529) -->
|
<!-- Overload Indicator (529) -->
|
||||||
<div v-if="isOverloaded" class="group relative">
|
<div v-if="isOverloaded" class="group relative">
|
||||||
@@ -154,17 +164,50 @@ const activeModelRateLimits = computed(() => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
const formatScopeName = (scope: string): string => {
|
const formatScopeName = (scope: string): string => {
|
||||||
const names: Record<string, string> = {
|
const aliases: Record<string, string> = {
|
||||||
|
// Claude 系列
|
||||||
|
'claude-opus-4-6-thinking': 'COpus46',
|
||||||
|
'claude-sonnet-4-6': 'CSon46',
|
||||||
|
'claude-sonnet-4-5': 'CSon45',
|
||||||
|
'claude-sonnet-4-5-thinking': 'CSon45T',
|
||||||
|
// Gemini 2.5 系列
|
||||||
|
'gemini-2.5-flash': 'G25F',
|
||||||
|
'gemini-2.5-flash-lite': 'G25FL',
|
||||||
|
'gemini-2.5-flash-thinking': 'G25FT',
|
||||||
|
'gemini-2.5-pro': 'G25P',
|
||||||
|
// Gemini 3 系列
|
||||||
|
'gemini-3-flash': 'G3F',
|
||||||
|
'gemini-3.1-pro-high': 'G3PH',
|
||||||
|
'gemini-3.1-pro-low': 'G3PL',
|
||||||
|
'gemini-3-pro-image': 'G3PI',
|
||||||
|
// 其他
|
||||||
|
'gpt-oss-120b-medium': 'GPT120',
|
||||||
|
'tab_flash_lite_preview': 'TabFL',
|
||||||
|
// 旧版 scope 别名(兼容)
|
||||||
claude: 'Claude',
|
claude: 'Claude',
|
||||||
claude_sonnet: 'Claude Sonnet',
|
claude_sonnet: 'CSon',
|
||||||
claude_opus: 'Claude Opus',
|
claude_opus: 'COpus',
|
||||||
claude_haiku: 'Claude Haiku',
|
claude_haiku: 'CHaiku',
|
||||||
gemini_text: 'Gemini',
|
gemini_text: 'Gemini',
|
||||||
gemini_image: 'Image',
|
gemini_image: 'GImg',
|
||||||
gemini_flash: 'Gemini Flash',
|
gemini_flash: 'GFlash',
|
||||||
gemini_pro: 'Gemini Pro'
|
gemini_pro: 'GPro',
|
||||||
}
|
}
|
||||||
return names[scope] || scope
|
return aliases[scope] || scope
|
||||||
|
}
|
||||||
|
|
||||||
|
const formatModelResetTime = (resetAt: string): string => {
|
||||||
|
const date = new Date(resetAt)
|
||||||
|
const now = new Date()
|
||||||
|
const diffMs = date.getTime() - now.getTime()
|
||||||
|
if (diffMs <= 0) return ''
|
||||||
|
const totalSecs = Math.floor(diffMs / 1000)
|
||||||
|
const h = Math.floor(totalSecs / 3600)
|
||||||
|
const m = Math.floor((totalSecs % 3600) / 60)
|
||||||
|
const s = totalSecs % 60
|
||||||
|
if (h > 0) return `${h}h${m}m`
|
||||||
|
if (m > 0) return `${m}m${s}s`
|
||||||
|
return `${s}s`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Computed: is overloaded (529)
|
// Computed: is overloaded (529)
|
||||||
|
|||||||
@@ -172,12 +172,12 @@
|
|||||||
color="purple"
|
color="purple"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<!-- Claude 4.5 -->
|
<!-- Claude -->
|
||||||
<UsageProgressBar
|
<UsageProgressBar
|
||||||
v-if="antigravityClaude45UsageFromAPI !== null"
|
v-if="antigravityClaudeUsageFromAPI !== null"
|
||||||
:label="t('admin.accounts.usageWindow.claude45')"
|
:label="t('admin.accounts.usageWindow.claude')"
|
||||||
:utilization="antigravityClaude45UsageFromAPI.utilization"
|
:utilization="antigravityClaudeUsageFromAPI.utilization"
|
||||||
:resets-at="antigravityClaude45UsageFromAPI.resetTime"
|
:resets-at="antigravityClaudeUsageFromAPI.resetTime"
|
||||||
color="amber"
|
color="amber"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@@ -400,9 +400,12 @@ const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(
|
|||||||
// Gemini 3 Image from API
|
// Gemini 3 Image from API
|
||||||
const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-pro-image']))
|
const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-pro-image']))
|
||||||
|
|
||||||
// Claude 4.5 from API
|
// Claude from API (all Claude model variants)
|
||||||
const antigravityClaude45UsageFromAPI = computed(() =>
|
const antigravityClaudeUsageFromAPI = computed(() =>
|
||||||
getAntigravityUsageFromAPI(['claude-sonnet-4-5', 'claude-opus-4-5-thinking'])
|
getAntigravityUsageFromAPI([
|
||||||
|
'claude-sonnet-4-5', 'claude-opus-4-5-thinking',
|
||||||
|
'claude-sonnet-4-6', 'claude-opus-4-6-thinking',
|
||||||
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
// Antigravity 账户类型(从 load_code_assist 响应中提取)
|
// Antigravity 账户类型(从 load_code_assist 响应中提取)
|
||||||
|
|||||||
@@ -209,7 +209,7 @@
|
|||||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||||
<div
|
<div
|
||||||
v-for="(mapping, index) in modelMappings"
|
v-for="(mapping, index) in modelMappings"
|
||||||
:key="getModelMappingKey(mapping)"
|
:key="index"
|
||||||
class="flex items-center gap-2"
|
class="flex items-center gap-2"
|
||||||
>
|
>
|
||||||
<input
|
<input
|
||||||
@@ -654,7 +654,7 @@ import Select from '@/components/common/Select.vue'
|
|||||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
import { buildModelMappingObject as buildModelMappingPayload } from '@/composables/useModelWhitelist'
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
show: boolean
|
show: boolean
|
||||||
@@ -696,7 +696,6 @@ const baseUrl = ref('')
|
|||||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||||
const allowedModels = ref<string[]>([])
|
const allowedModels = ref<string[]>([])
|
||||||
const modelMappings = ref<ModelMapping[]>([])
|
const modelMappings = ref<ModelMapping[]>([])
|
||||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('bulk-model-mapping')
|
|
||||||
const selectedErrorCodes = ref<number[]>([])
|
const selectedErrorCodes = ref<number[]>([])
|
||||||
const customErrorCodeInput = ref<number | null>(null)
|
const customErrorCodeInput = ref<number | null>(null)
|
||||||
const interceptWarmupRequests = ref(false)
|
const interceptWarmupRequests = ref(false)
|
||||||
@@ -707,7 +706,7 @@ const rateMultiplier = ref(1)
|
|||||||
const status = ref<'active' | 'inactive'>('active')
|
const status = ref<'active' | 'inactive'>('active')
|
||||||
const groupIds = ref<number[]>([])
|
const groupIds = ref<number[]>([])
|
||||||
|
|
||||||
// All models list (combined Anthropic + OpenAI)
|
// All models list (combined Anthropic + OpenAI + Gemini)
|
||||||
const allModels = [
|
const allModels = [
|
||||||
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
|
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
|
||||||
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
|
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
|
||||||
@@ -719,17 +718,21 @@ const allModels = [
|
|||||||
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
|
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
|
||||||
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
|
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
|
||||||
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
|
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
|
||||||
{ value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' },
|
|
||||||
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
||||||
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
||||||
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
|
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
|
||||||
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
|
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
|
||||||
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
||||||
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
||||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' }
|
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
|
||||||
|
{ value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' },
|
||||||
|
{ value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' },
|
||||||
|
{ value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' },
|
||||||
|
{ value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' },
|
||||||
|
{ value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' }
|
||||||
]
|
]
|
||||||
|
|
||||||
// Preset mappings (combined Anthropic + OpenAI)
|
// Preset mappings (combined Anthropic + OpenAI + Gemini)
|
||||||
const presetMappings = [
|
const presetMappings = [
|
||||||
{
|
{
|
||||||
label: 'Sonnet 4',
|
label: 'Sonnet 4',
|
||||||
@@ -765,18 +768,37 @@ const presetMappings = [
|
|||||||
color:
|
color:
|
||||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
label: 'Sonnet4→4.6',
|
||||||
|
from: 'claude-sonnet-4-20250514',
|
||||||
|
to: 'claude-sonnet-4-6',
|
||||||
|
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Sonnet4.5→4.6',
|
||||||
|
from: 'claude-sonnet-4-5-20250929',
|
||||||
|
to: 'claude-sonnet-4-6',
|
||||||
|
color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Sonnet3.5→4.6',
|
||||||
|
from: 'claude-3-5-sonnet-20241022',
|
||||||
|
to: 'claude-sonnet-4-6',
|
||||||
|
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Opus4.5→4.6',
|
||||||
|
from: 'claude-opus-4-5-20251101',
|
||||||
|
to: 'claude-opus-4-6-thinking',
|
||||||
|
color:
|
||||||
|
'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400'
|
||||||
|
},
|
||||||
{
|
{
|
||||||
label: 'Opus->Sonnet',
|
label: 'Opus->Sonnet',
|
||||||
from: 'claude-opus-4-5-20251101',
|
from: 'claude-opus-4-5-20251101',
|
||||||
to: 'claude-sonnet-4-5-20250929',
|
to: 'claude-sonnet-4-5-20250929',
|
||||||
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||||
},
|
},
|
||||||
{
|
|
||||||
label: 'GPT-5.3 Codex Spark',
|
|
||||||
from: 'gpt-5.3-codex-spark',
|
|
||||||
to: 'gpt-5.3-codex-spark',
|
|
||||||
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
label: 'GPT-5.2',
|
label: 'GPT-5.2',
|
||||||
from: 'gpt-5.2-2025-12-11',
|
from: 'gpt-5.2-2025-12-11',
|
||||||
@@ -794,6 +816,36 @@ const presetMappings = [
|
|||||||
from: 'gpt-5.1-codex-max',
|
from: 'gpt-5.1-codex-max',
|
||||||
to: 'gpt-5.1-codex',
|
to: 'gpt-5.1-codex',
|
||||||
color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400'
|
color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '3-Pro-Preview→3.1-Pro-High',
|
||||||
|
from: 'gemini-3-pro-preview',
|
||||||
|
to: 'gemini-3.1-pro-high',
|
||||||
|
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '3-Pro-High→3.1-Pro-High',
|
||||||
|
from: 'gemini-3-pro-high',
|
||||||
|
to: 'gemini-3.1-pro-high',
|
||||||
|
color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '3-Pro-Low→3.1-Pro-Low',
|
||||||
|
from: 'gemini-3-pro-low',
|
||||||
|
to: 'gemini-3.1-pro-low',
|
||||||
|
color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '3-Flash透传',
|
||||||
|
from: 'gemini-3-flash',
|
||||||
|
to: 'gemini-3-flash',
|
||||||
|
color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '2.5-Flash-Lite透传',
|
||||||
|
from: 'gemini-2.5-flash-lite',
|
||||||
|
to: 'gemini-2.5-flash-lite',
|
||||||
|
color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400'
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -883,23 +935,11 @@ const removeErrorCode = (code: number) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const buildModelMappingObject = (): Record<string, string> | null => {
|
const buildModelMappingObject = (): Record<string, string> | null => {
|
||||||
const mapping: Record<string, string> = {}
|
return buildModelMappingPayload(
|
||||||
|
modelRestrictionMode.value,
|
||||||
if (modelRestrictionMode.value === 'whitelist') {
|
allowedModels.value,
|
||||||
for (const model of allowedModels.value) {
|
modelMappings.value
|
||||||
mapping[model] = model
|
)
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (const m of modelMappings.value) {
|
|
||||||
const from = m.from.trim()
|
|
||||||
const to = m.to.trim()
|
|
||||||
if (from && to) {
|
|
||||||
mapping[from] = to
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Object.keys(mapping).length > 0 ? mapping : null
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const buildUpdatePayload = (): Record<string, unknown> | null => {
|
const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||||
|
|||||||
@@ -916,8 +916,8 @@
|
|||||||
<p class="input-hint">{{ t('admin.accounts.gemini.tier.aiStudioHint') }}</p>
|
<p class="input-hint">{{ t('admin.accounts.gemini.tier.aiStudioHint') }}</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Model Restriction Section (不适用于 Gemini,Antigravity 已在上层条件排除) -->
|
<!-- Model Restriction Section (Antigravity 已在上层条件排除) -->
|
||||||
<div v-if="form.platform !== 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
@@ -1200,34 +1200,6 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Gemini 模型说明 -->
|
|
||||||
<div v-if="form.platform === 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
|
||||||
<div class="rounded-lg bg-blue-50 p-4 dark:bg-blue-900/20">
|
|
||||||
<div class="flex items-start gap-3">
|
|
||||||
<svg
|
|
||||||
class="h-5 w-5 flex-shrink-0 text-blue-600 dark:text-blue-400"
|
|
||||||
fill="none"
|
|
||||||
viewBox="0 0 24 24"
|
|
||||||
stroke="currentColor"
|
|
||||||
>
|
|
||||||
<path
|
|
||||||
stroke-linecap="round"
|
|
||||||
stroke-linejoin="round"
|
|
||||||
stroke-width="2"
|
|
||||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
<div>
|
|
||||||
<p class="text-sm font-medium text-blue-800 dark:text-blue-300">
|
|
||||||
{{ t('admin.accounts.gemini.modelPassthrough') }}
|
|
||||||
</p>
|
|
||||||
<p class="mt-1 text-xs text-blue-700 dark:text-blue-400">
|
|
||||||
{{ t('admin.accounts.gemini.modelPassthroughDesc') }}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Temp Unschedulable Rules -->
|
<!-- Temp Unschedulable Rules -->
|
||||||
@@ -1378,9 +1350,9 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Intercept Warmup Requests (Anthropic only) -->
|
<!-- Intercept Warmup Requests (Anthropic/Antigravity) -->
|
||||||
<div
|
<div
|
||||||
v-if="form.platform === 'anthropic'"
|
v-if="form.platform === 'anthropic' || form.platform === 'antigravity'"
|
||||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
<div class="flex items-center justify-between">
|
<div class="flex items-center justify-between">
|
||||||
@@ -2157,7 +2129,7 @@
|
|||||||
<ConfirmDialog
|
<ConfirmDialog
|
||||||
:show="showMixedChannelWarning"
|
:show="showMixedChannelWarning"
|
||||||
:title="t('admin.accounts.mixedChannelWarningTitle')"
|
:title="t('admin.accounts.mixedChannelWarningTitle')"
|
||||||
:message="mixedChannelWarningDetails ? t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails) : ''"
|
:message="mixedChannelWarningMessageText"
|
||||||
:confirm-text="t('common.confirm')"
|
:confirm-text="t('common.confirm')"
|
||||||
:cancel-text="t('common.cancel')"
|
:cancel-text="t('common.cancel')"
|
||||||
:danger="true"
|
:danger="true"
|
||||||
@@ -2189,13 +2161,21 @@ import {
|
|||||||
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
||||||
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
|
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
|
||||||
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
|
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
|
||||||
import type { Proxy, AdminGroup, AccountPlatform, AccountType } from '@/types'
|
import type {
|
||||||
|
Proxy,
|
||||||
|
AdminGroup,
|
||||||
|
AccountPlatform,
|
||||||
|
AccountType,
|
||||||
|
CheckMixedChannelResponse,
|
||||||
|
CreateAccountRequest
|
||||||
|
} from '@/types'
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||||
|
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
||||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
||||||
@@ -2337,10 +2317,13 @@ const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>
|
|||||||
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
|
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
|
||||||
const geminiAIStudioOAuthEnabled = ref(false)
|
const geminiAIStudioOAuthEnabled = ref(false)
|
||||||
|
|
||||||
// Mixed channel warning dialog state
|
|
||||||
const showMixedChannelWarning = ref(false)
|
const showMixedChannelWarning = ref(false)
|
||||||
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null)
|
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
|
||||||
const pendingCreatePayload = ref<any>(null)
|
null
|
||||||
|
)
|
||||||
|
const mixedChannelWarningRawMessage = ref('')
|
||||||
|
const mixedChannelWarningAction = ref<(() => Promise<void>) | null>(null)
|
||||||
|
const antigravityMixedChannelConfirmed = ref(false)
|
||||||
const showAdvancedOAuth = ref(false)
|
const showAdvancedOAuth = ref(false)
|
||||||
const showGeminiHelpDialog = ref(false)
|
const showGeminiHelpDialog = ref(false)
|
||||||
|
|
||||||
@@ -2378,6 +2361,13 @@ const isOpenAIModelRestrictionDisabled = computed(() =>
|
|||||||
form.platform === 'openai' && openaiPassthroughEnabled.value
|
form.platform === 'openai' && openaiPassthroughEnabled.value
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const mixedChannelWarningMessageText = computed(() => {
|
||||||
|
if (mixedChannelWarningDetails.value) {
|
||||||
|
return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value)
|
||||||
|
}
|
||||||
|
return mixedChannelWarningRawMessage.value
|
||||||
|
})
|
||||||
|
|
||||||
const geminiQuotaDocs = {
|
const geminiQuotaDocs = {
|
||||||
codeAssist: 'https://developers.google.com/gemini-code-assist/resources/quotas',
|
codeAssist: 'https://developers.google.com/gemini-code-assist/resources/quotas',
|
||||||
aiStudio: 'https://ai.google.dev/pricing',
|
aiStudio: 'https://ai.google.dev/pricing',
|
||||||
@@ -2544,8 +2534,8 @@ watch(
|
|||||||
antigravityModelMappings.value = []
|
antigravityModelMappings.value = []
|
||||||
antigravityModelRestrictionMode.value = 'mapping'
|
antigravityModelRestrictionMode.value = 'mapping'
|
||||||
}
|
}
|
||||||
// Reset Anthropic-specific settings when switching to other platforms
|
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
|
||||||
if (newPlatform !== 'anthropic') {
|
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
|
||||||
interceptWarmupRequests.value = false
|
interceptWarmupRequests.value = false
|
||||||
}
|
}
|
||||||
if (newPlatform === 'sora') {
|
if (newPlatform === 'sora') {
|
||||||
@@ -2794,6 +2784,105 @@ const splitTempUnschedKeywords = (value: string) => {
|
|||||||
.filter((item) => item.length > 0)
|
.filter((item) => item.length > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const needsMixedChannelCheck = (platform: AccountPlatform) => platform === 'antigravity' || platform === 'anthropic'
|
||||||
|
|
||||||
|
const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => {
|
||||||
|
const details = resp?.details
|
||||||
|
if (!details) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
groupName: details.group_name || 'Unknown',
|
||||||
|
currentPlatform: details.current_platform || 'Unknown',
|
||||||
|
otherPlatform: details.other_platform || 'Unknown'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const clearMixedChannelDialog = () => {
|
||||||
|
showMixedChannelWarning.value = false
|
||||||
|
mixedChannelWarningDetails.value = null
|
||||||
|
mixedChannelWarningRawMessage.value = ''
|
||||||
|
mixedChannelWarningAction.value = null
|
||||||
|
}
|
||||||
|
|
||||||
|
const openMixedChannelDialog = (opts: {
|
||||||
|
response?: CheckMixedChannelResponse
|
||||||
|
message?: string
|
||||||
|
onConfirm: () => Promise<void>
|
||||||
|
}) => {
|
||||||
|
mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response)
|
||||||
|
mixedChannelWarningRawMessage.value =
|
||||||
|
opts.message || opts.response?.message || t('admin.accounts.failedToCreate')
|
||||||
|
mixedChannelWarningAction.value = opts.onConfirm
|
||||||
|
showMixedChannelWarning.value = true
|
||||||
|
}
|
||||||
|
|
||||||
|
const withAntigravityConfirmFlag = (payload: CreateAccountRequest): CreateAccountRequest => {
|
||||||
|
if (needsMixedChannelCheck(payload.platform) && antigravityMixedChannelConfirmed.value) {
|
||||||
|
return {
|
||||||
|
...payload,
|
||||||
|
confirm_mixed_channel_risk: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const cloned = { ...payload }
|
||||||
|
delete cloned.confirm_mixed_channel_risk
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise<void>): Promise<boolean> => {
|
||||||
|
if (!needsMixedChannelCheck(form.platform)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (antigravityMixedChannelConfirmed.value) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await adminAPI.accounts.checkMixedChannelRisk({
|
||||||
|
platform: form.platform,
|
||||||
|
group_ids: form.group_ids
|
||||||
|
})
|
||||||
|
if (!result.has_risk) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
openMixedChannelDialog({
|
||||||
|
response: result,
|
||||||
|
onConfirm: async () => {
|
||||||
|
antigravityMixedChannelConfirmed.value = true
|
||||||
|
await onConfirm()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return false
|
||||||
|
} catch (error: any) {
|
||||||
|
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const submitCreateAccount = async (payload: CreateAccountRequest) => {
|
||||||
|
submitting.value = true
|
||||||
|
try {
|
||||||
|
await adminAPI.accounts.create(withAntigravityConfirmFlag(payload))
|
||||||
|
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||||
|
emit('created')
|
||||||
|
handleClose()
|
||||||
|
} catch (error: any) {
|
||||||
|
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck(form.platform)) {
|
||||||
|
openMixedChannelDialog({
|
||||||
|
message: error.response?.data?.message,
|
||||||
|
onConfirm: async () => {
|
||||||
|
antigravityMixedChannelConfirmed.value = true
|
||||||
|
await submitCreateAccount(payload)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
||||||
|
} finally {
|
||||||
|
submitting.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Methods
|
// Methods
|
||||||
const resetForm = () => {
|
const resetForm = () => {
|
||||||
step.value = 1
|
step.value = 1
|
||||||
@@ -2855,9 +2944,13 @@ const resetForm = () => {
|
|||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
|
antigravityMixedChannelConfirmed.value = false
|
||||||
|
clearMixedChannelDialog()
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleClose = () => {
|
const handleClose = () => {
|
||||||
|
antigravityMixedChannelConfirmed.value = false
|
||||||
|
clearMixedChannelDialog()
|
||||||
emit('close')
|
emit('close')
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2916,56 +3009,34 @@ const buildSoraExtra = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to create account with mixed channel warning handling
|
// Helper function to create account with mixed channel warning handling
|
||||||
const doCreateAccount = async (payload: any) => {
|
const doCreateAccount = async (payload: CreateAccountRequest) => {
|
||||||
|
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
|
||||||
|
await submitCreateAccount(payload)
|
||||||
|
})
|
||||||
|
if (!canContinue) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
await submitCreateAccount(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle mixed channel warning confirmation
|
||||||
|
const handleMixedChannelConfirm = async () => {
|
||||||
|
const action = mixedChannelWarningAction.value
|
||||||
|
if (!action) {
|
||||||
|
clearMixedChannelDialog()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clearMixedChannelDialog()
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
try {
|
try {
|
||||||
await adminAPI.accounts.create(payload)
|
await action()
|
||||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
|
||||||
emit('created')
|
|
||||||
handleClose()
|
|
||||||
} catch (error: any) {
|
|
||||||
// Handle 409 mixed_channel_warning - show confirmation dialog
|
|
||||||
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') {
|
|
||||||
const details = error.response.data.details || {}
|
|
||||||
mixedChannelWarningDetails.value = {
|
|
||||||
groupName: details.group_name || 'Unknown',
|
|
||||||
currentPlatform: details.current_platform || 'Unknown',
|
|
||||||
otherPlatform: details.other_platform || 'Unknown'
|
|
||||||
}
|
|
||||||
pendingCreatePayload.value = payload
|
|
||||||
showMixedChannelWarning.value = true
|
|
||||||
} else {
|
|
||||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
|
||||||
}
|
|
||||||
} finally {
|
} finally {
|
||||||
submitting.value = false
|
submitting.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle mixed channel warning confirmation
|
|
||||||
const handleMixedChannelConfirm = async () => {
|
|
||||||
showMixedChannelWarning.value = false
|
|
||||||
if (pendingCreatePayload.value) {
|
|
||||||
pendingCreatePayload.value.confirm_mixed_channel_risk = true
|
|
||||||
submitting.value = true
|
|
||||||
try {
|
|
||||||
await adminAPI.accounts.create(pendingCreatePayload.value)
|
|
||||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
|
||||||
emit('created')
|
|
||||||
handleClose()
|
|
||||||
} catch (error: any) {
|
|
||||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
|
||||||
} finally {
|
|
||||||
submitting.value = false
|
|
||||||
pendingCreatePayload.value = null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleMixedChannelCancel = () => {
|
const handleMixedChannelCancel = () => {
|
||||||
showMixedChannelWarning.value = false
|
clearMixedChannelDialog()
|
||||||
pendingCreatePayload.value = null
|
|
||||||
mixedChannelWarningDetails.value = null
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleSubmit = async () => {
|
const handleSubmit = async () => {
|
||||||
@@ -2975,6 +3046,12 @@ const handleSubmit = async () => {
|
|||||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
|
||||||
|
step.value = 2
|
||||||
|
})
|
||||||
|
if (!canContinue) {
|
||||||
|
return
|
||||||
|
}
|
||||||
step.value = 2
|
step.value = 2
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -3010,15 +3087,10 @@ const handleSubmit = async () => {
|
|||||||
credentials.model_mapping = antigravityModelMapping
|
credentials.model_mapping = antigravityModelMapping
|
||||||
}
|
}
|
||||||
|
|
||||||
submitting.value = true
|
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||||
try {
|
|
||||||
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
|
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
|
||||||
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
|
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
|
||||||
} catch (error: any) {
|
|
||||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
|
||||||
} finally {
|
|
||||||
submitting.value = false
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3059,10 +3131,7 @@ const handleSubmit = async () => {
|
|||||||
credentials.custom_error_codes = [...selectedErrorCodes.value]
|
credentials.custom_error_codes = [...selectedErrorCodes.value]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add intercept warmup requests setting
|
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||||
if (interceptWarmupRequests.value) {
|
|
||||||
credentials.intercept_warmup_requests = true
|
|
||||||
}
|
|
||||||
if (!applyTempUnschedConfig(credentials)) {
|
if (!applyTempUnschedConfig(credentials)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -3132,7 +3201,7 @@ const createAccountAndFinish = async (
|
|||||||
if (!applyTempUnschedConfig(credentials)) {
|
if (!applyTempUnschedConfig(credentials)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
await adminAPI.accounts.create({
|
await doCreateAccount({
|
||||||
name: form.name,
|
name: form.name,
|
||||||
notes: form.notes,
|
notes: form.notes,
|
||||||
platform,
|
platform,
|
||||||
@@ -3147,9 +3216,6 @@ const createAccountAndFinish = async (
|
|||||||
expires_at: form.expires_at,
|
expires_at: form.expires_at,
|
||||||
auto_pause_on_expired: autoPauseOnExpired.value
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
})
|
})
|
||||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
|
||||||
emit('created')
|
|
||||||
handleClose()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI OAuth 授权码兑换
|
// OpenAI OAuth 授权码兑换
|
||||||
@@ -3497,7 +3563,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
|
|||||||
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||||
|
|
||||||
// Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials
|
// Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials
|
||||||
await adminAPI.accounts.create({
|
const createPayload = withAntigravityConfirmFlag({
|
||||||
name: accountName,
|
name: accountName,
|
||||||
notes: form.notes,
|
notes: form.notes,
|
||||||
platform: 'antigravity',
|
platform: 'antigravity',
|
||||||
@@ -3512,6 +3578,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
|
|||||||
expires_at: form.expires_at,
|
expires_at: form.expires_at,
|
||||||
auto_pause_on_expired: autoPauseOnExpired.value
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
})
|
})
|
||||||
|
await adminAPI.accounts.create(createPayload)
|
||||||
successCount++
|
successCount++
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
failedCount++
|
failedCount++
|
||||||
@@ -3606,6 +3673,7 @@ const handleAntigravityExchange = async (authCode: string) => {
|
|||||||
if (!tokenInfo) return
|
if (!tokenInfo) return
|
||||||
|
|
||||||
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
|
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
|
||||||
|
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||||
// Antigravity 只使用映射模式
|
// Antigravity 只使用映射模式
|
||||||
const antigravityModelMapping = buildModelMappingObject(
|
const antigravityModelMapping = buildModelMappingObject(
|
||||||
'mapping',
|
'mapping',
|
||||||
@@ -3677,10 +3745,8 @@ const handleAnthropicExchange = async (authCode: string) => {
|
|||||||
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
|
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
|
||||||
}
|
}
|
||||||
|
|
||||||
const credentials = {
|
const credentials: Record<string, unknown> = { ...tokenInfo }
|
||||||
...tokenInfo,
|
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||||
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
|
|
||||||
}
|
|
||||||
await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra)
|
await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
oauth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
oauth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
@@ -3779,11 +3845,8 @@ const handleCookieAuth = async (sessionKey: string) => {
|
|||||||
|
|
||||||
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
|
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||||
|
|
||||||
// Merge interceptWarmupRequests into credentials
|
const credentials: Record<string, unknown> = { ...tokenInfo }
|
||||||
const credentials: Record<string, unknown> = {
|
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||||
...tokenInfo,
|
|
||||||
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
|
|
||||||
}
|
|
||||||
if (tempUnschedEnabled.value) {
|
if (tempUnschedEnabled.value) {
|
||||||
credentials.temp_unschedulable_enabled = true
|
credentials.temp_unschedulable_enabled = true
|
||||||
credentials.temp_unschedulable_rules = tempUnschedPayload
|
credentials.temp_unschedulable_rules = tempUnschedPayload
|
||||||
|
|||||||
@@ -65,8 +65,8 @@
|
|||||||
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
|
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Model Restriction Section (不适用于 Gemini 和 Antigravity) -->
|
<!-- Model Restriction Section (不适用于 Antigravity) -->
|
||||||
<div v-if="account.platform !== 'gemini' && account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
<div v-if="account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
@@ -349,34 +349,6 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Gemini 模型说明 -->
|
|
||||||
<div v-if="account.platform === 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
|
||||||
<div class="rounded-lg bg-blue-50 p-4 dark:bg-blue-900/20">
|
|
||||||
<div class="flex items-start gap-3">
|
|
||||||
<svg
|
|
||||||
class="h-5 w-5 flex-shrink-0 text-blue-600 dark:text-blue-400"
|
|
||||||
fill="none"
|
|
||||||
viewBox="0 0 24 24"
|
|
||||||
stroke="currentColor"
|
|
||||||
>
|
|
||||||
<path
|
|
||||||
stroke-linecap="round"
|
|
||||||
stroke-linejoin="round"
|
|
||||||
stroke-width="2"
|
|
||||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
<div>
|
|
||||||
<p class="text-sm font-medium text-blue-800 dark:text-blue-300">
|
|
||||||
{{ t('admin.accounts.gemini.modelPassthrough') }}
|
|
||||||
</p>
|
|
||||||
<p class="mt-1 text-xs text-blue-700 dark:text-blue-400">
|
|
||||||
{{ t('admin.accounts.gemini.modelPassthroughDesc') }}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Upstream fields (only for upstream type) -->
|
<!-- Upstream fields (only for upstream type) -->
|
||||||
@@ -641,9 +613,9 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Intercept Warmup Requests (Anthropic only) -->
|
<!-- Intercept Warmup Requests (Anthropic/Antigravity) -->
|
||||||
<div
|
<div
|
||||||
v-if="account?.platform === 'anthropic'"
|
v-if="account?.platform === 'anthropic' || account?.platform === 'antigravity'"
|
||||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
<div class="flex items-center justify-between">
|
<div class="flex items-center justify-between">
|
||||||
@@ -1139,7 +1111,7 @@
|
|||||||
<ConfirmDialog
|
<ConfirmDialog
|
||||||
:show="showMixedChannelWarning"
|
:show="showMixedChannelWarning"
|
||||||
:title="t('admin.accounts.mixedChannelWarningTitle')"
|
:title="t('admin.accounts.mixedChannelWarningTitle')"
|
||||||
:message="mixedChannelWarningDetails ? t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails) : ''"
|
:message="mixedChannelWarningMessageText"
|
||||||
:confirm-text="t('common.confirm')"
|
:confirm-text="t('common.confirm')"
|
||||||
:cancel-text="t('common.cancel')"
|
:cancel-text="t('common.cancel')"
|
||||||
:danger="true"
|
:danger="true"
|
||||||
@@ -1154,7 +1126,7 @@ import { useI18n } from 'vue-i18n'
|
|||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { useAuthStore } from '@/stores/auth'
|
import { useAuthStore } from '@/stores/auth'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import type { Account, Proxy, AdminGroup } from '@/types'
|
import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse } from '@/types'
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
@@ -1162,6 +1134,7 @@ import Icon from '@/components/icons/Icon.vue'
|
|||||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||||
|
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
||||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
import {
|
import {
|
||||||
@@ -1233,10 +1206,13 @@ const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-mod
|
|||||||
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
|
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
|
||||||
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
|
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
|
||||||
|
|
||||||
// Mixed channel warning dialog state
|
|
||||||
const showMixedChannelWarning = ref(false)
|
const showMixedChannelWarning = ref(false)
|
||||||
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null)
|
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
|
||||||
const pendingUpdatePayload = ref<Record<string, unknown> | null>(null)
|
null
|
||||||
|
)
|
||||||
|
const mixedChannelWarningRawMessage = ref('')
|
||||||
|
const mixedChannelWarningAction = ref<(() => Promise<void>) | null>(null)
|
||||||
|
const antigravityMixedChannelConfirmed = ref(false)
|
||||||
|
|
||||||
// Quota control state (Anthropic OAuth/SetupToken only)
|
// Quota control state (Anthropic OAuth/SetupToken only)
|
||||||
const windowCostEnabled = ref(false)
|
const windowCostEnabled = ref(false)
|
||||||
@@ -1297,6 +1273,13 @@ const defaultBaseUrl = computed(() => {
|
|||||||
return 'https://api.anthropic.com'
|
return 'https://api.anthropic.com'
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const mixedChannelWarningMessageText = computed(() => {
|
||||||
|
if (mixedChannelWarningDetails.value) {
|
||||||
|
return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value)
|
||||||
|
}
|
||||||
|
return mixedChannelWarningRawMessage.value
|
||||||
|
})
|
||||||
|
|
||||||
const form = reactive({
|
const form = reactive({
|
||||||
name: '',
|
name: '',
|
||||||
notes: '',
|
notes: '',
|
||||||
@@ -1326,6 +1309,11 @@ watch(
|
|||||||
() => props.account,
|
() => props.account,
|
||||||
(newAccount) => {
|
(newAccount) => {
|
||||||
if (newAccount) {
|
if (newAccount) {
|
||||||
|
antigravityMixedChannelConfirmed.value = false
|
||||||
|
showMixedChannelWarning.value = false
|
||||||
|
mixedChannelWarningDetails.value = null
|
||||||
|
mixedChannelWarningRawMessage.value = ''
|
||||||
|
mixedChannelWarningAction.value = null
|
||||||
form.name = newAccount.name
|
form.name = newAccount.name
|
||||||
form.notes = newAccount.notes || ''
|
form.notes = newAccount.notes || ''
|
||||||
form.proxy_id = newAccount.proxy_id
|
form.proxy_id = newAccount.proxy_id
|
||||||
@@ -1725,18 +1713,123 @@ function toPositiveNumber(value: unknown) {
|
|||||||
return Math.trunc(num)
|
return Math.trunc(num)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const needsMixedChannelCheck = () => props.account?.platform === 'antigravity' || props.account?.platform === 'anthropic'
|
||||||
|
|
||||||
|
const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => {
|
||||||
|
const details = resp?.details
|
||||||
|
if (!details) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
groupName: details.group_name || 'Unknown',
|
||||||
|
currentPlatform: details.current_platform || 'Unknown',
|
||||||
|
otherPlatform: details.other_platform || 'Unknown'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const clearMixedChannelDialog = () => {
|
||||||
|
showMixedChannelWarning.value = false
|
||||||
|
mixedChannelWarningDetails.value = null
|
||||||
|
mixedChannelWarningRawMessage.value = ''
|
||||||
|
mixedChannelWarningAction.value = null
|
||||||
|
}
|
||||||
|
|
||||||
|
const openMixedChannelDialog = (opts: {
|
||||||
|
response?: CheckMixedChannelResponse
|
||||||
|
message?: string
|
||||||
|
onConfirm: () => Promise<void>
|
||||||
|
}) => {
|
||||||
|
mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response)
|
||||||
|
mixedChannelWarningRawMessage.value =
|
||||||
|
opts.message || opts.response?.message || t('admin.accounts.failedToUpdate')
|
||||||
|
mixedChannelWarningAction.value = opts.onConfirm
|
||||||
|
showMixedChannelWarning.value = true
|
||||||
|
}
|
||||||
|
|
||||||
|
const withAntigravityConfirmFlag = (payload: Record<string, unknown>) => {
|
||||||
|
if (needsMixedChannelCheck() && antigravityMixedChannelConfirmed.value) {
|
||||||
|
return {
|
||||||
|
...payload,
|
||||||
|
confirm_mixed_channel_risk: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const cloned = { ...payload }
|
||||||
|
delete cloned.confirm_mixed_channel_risk
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise<void>): Promise<boolean> => {
|
||||||
|
if (!needsMixedChannelCheck()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (antigravityMixedChannelConfirmed.value) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (!props.account) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await adminAPI.accounts.checkMixedChannelRisk({
|
||||||
|
platform: props.account.platform,
|
||||||
|
group_ids: form.group_ids,
|
||||||
|
account_id: props.account.id
|
||||||
|
})
|
||||||
|
if (!result.has_risk) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
openMixedChannelDialog({
|
||||||
|
response: result,
|
||||||
|
onConfirm: async () => {
|
||||||
|
antigravityMixedChannelConfirmed.value = true
|
||||||
|
await onConfirm()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return false
|
||||||
|
} catch (error: any) {
|
||||||
|
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const formatDateTimeLocal = formatDateTimeLocalInput
|
const formatDateTimeLocal = formatDateTimeLocalInput
|
||||||
const parseDateTimeLocal = parseDateTimeLocalInput
|
const parseDateTimeLocal = parseDateTimeLocalInput
|
||||||
|
|
||||||
// Methods
|
// Methods
|
||||||
const handleClose = () => {
|
const handleClose = () => {
|
||||||
|
antigravityMixedChannelConfirmed.value = false
|
||||||
|
clearMixedChannelDialog()
|
||||||
emit('close')
|
emit('close')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const submitUpdateAccount = async (accountID: number, updatePayload: Record<string, unknown>) => {
|
||||||
|
submitting.value = true
|
||||||
|
try {
|
||||||
|
const updatedAccount = await adminAPI.accounts.update(accountID, withAntigravityConfirmFlag(updatePayload))
|
||||||
|
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
||||||
|
emit('updated', updatedAccount)
|
||||||
|
handleClose()
|
||||||
|
} catch (error: any) {
|
||||||
|
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck()) {
|
||||||
|
openMixedChannelDialog({
|
||||||
|
message: error.response?.data?.message,
|
||||||
|
onConfirm: async () => {
|
||||||
|
antigravityMixedChannelConfirmed.value = true
|
||||||
|
await submitUpdateAccount(accountID, updatePayload)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||||
|
} finally {
|
||||||
|
submitting.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const handleSubmit = async () => {
|
const handleSubmit = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
|
const accountID = props.account.id
|
||||||
|
|
||||||
submitting.value = true
|
|
||||||
const updatePayload: Record<string, unknown> = { ...form }
|
const updatePayload: Record<string, unknown> = { ...form }
|
||||||
try {
|
try {
|
||||||
// 后端期望 proxy_id: 0 表示清除代理,而不是 null
|
// 后端期望 proxy_id: 0 表示清除代理,而不是 null
|
||||||
@@ -1768,7 +1861,6 @@ const handleSubmit = async () => {
|
|||||||
newCredentials.api_key = currentCredentials.api_key
|
newCredentials.api_key = currentCredentials.api_key
|
||||||
} else {
|
} else {
|
||||||
appStore.showError(t('admin.accounts.apiKeyIsRequired'))
|
appStore.showError(t('admin.accounts.apiKeyIsRequired'))
|
||||||
submitting.value = false
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1789,11 +1881,8 @@ const handleSubmit = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add intercept warmup requests setting
|
// Add intercept warmup requests setting
|
||||||
if (interceptWarmupRequests.value) {
|
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||||
newCredentials.intercept_warmup_requests = true
|
|
||||||
}
|
|
||||||
if (!applyTempUnschedConfig(newCredentials)) {
|
if (!applyTempUnschedConfig(newCredentials)) {
|
||||||
submitting.value = false
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1808,8 +1897,10 @@ const handleSubmit = async () => {
|
|||||||
newCredentials.api_key = editApiKey.value.trim()
|
newCredentials.api_key = editApiKey.value.trim()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add intercept warmup requests setting
|
||||||
|
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||||
|
|
||||||
if (!applyTempUnschedConfig(newCredentials)) {
|
if (!applyTempUnschedConfig(newCredentials)) {
|
||||||
submitting.value = false
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1819,13 +1910,8 @@ const handleSubmit = async () => {
|
|||||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||||
|
|
||||||
if (interceptWarmupRequests.value) {
|
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||||
newCredentials.intercept_warmup_requests = true
|
|
||||||
} else {
|
|
||||||
delete newCredentials.intercept_warmup_requests
|
|
||||||
}
|
|
||||||
if (!applyTempUnschedConfig(newCredentials)) {
|
if (!applyTempUnschedConfig(newCredentials)) {
|
||||||
submitting.value = false
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1955,52 +2041,36 @@ const handleSubmit = async () => {
|
|||||||
updatePayload.extra = newExtra
|
updatePayload.extra = newExtra
|
||||||
}
|
}
|
||||||
|
|
||||||
const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload)
|
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
|
||||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
await submitUpdateAccount(accountID, updatePayload)
|
||||||
emit('updated', updatedAccount)
|
})
|
||||||
handleClose()
|
if (!canContinue) {
|
||||||
} catch (error: any) {
|
return
|
||||||
// Handle 409 mixed_channel_warning - show confirmation dialog
|
|
||||||
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') {
|
|
||||||
const details = error.response.data.details || {}
|
|
||||||
mixedChannelWarningDetails.value = {
|
|
||||||
groupName: details.group_name || 'Unknown',
|
|
||||||
currentPlatform: details.current_platform || 'Unknown',
|
|
||||||
otherPlatform: details.other_platform || 'Unknown'
|
|
||||||
}
|
|
||||||
pendingUpdatePayload.value = updatePayload
|
|
||||||
showMixedChannelWarning.value = true
|
|
||||||
} else {
|
|
||||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
|
||||||
}
|
}
|
||||||
} finally {
|
|
||||||
submitting.value = false
|
await submitUpdateAccount(accountID, updatePayload)
|
||||||
|
} catch (error: any) {
|
||||||
|
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle mixed channel warning confirmation
|
// Handle mixed channel warning confirmation
|
||||||
const handleMixedChannelConfirm = async () => {
|
const handleMixedChannelConfirm = async () => {
|
||||||
showMixedChannelWarning.value = false
|
const action = mixedChannelWarningAction.value
|
||||||
if (pendingUpdatePayload.value && props.account) {
|
if (!action) {
|
||||||
pendingUpdatePayload.value.confirm_mixed_channel_risk = true
|
clearMixedChannelDialog()
|
||||||
submitting.value = true
|
return
|
||||||
try {
|
}
|
||||||
const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
|
clearMixedChannelDialog()
|
||||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
submitting.value = true
|
||||||
emit('updated', updatedAccount)
|
try {
|
||||||
handleClose()
|
await action()
|
||||||
} catch (error: any) {
|
} finally {
|
||||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
submitting.value = false
|
||||||
} finally {
|
|
||||||
submitting.value = false
|
|
||||||
pendingUpdatePayload.value = null
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleMixedChannelCancel = () => {
|
const handleMixedChannelCancel = () => {
|
||||||
showMixedChannelWarning.value = false
|
clearMixedChannelDialog()
|
||||||
pendingUpdatePayload.value = null
|
|
||||||
mixedChannelWarningDetails.value = null
|
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
import { describe, it, expect } from 'vitest'
|
||||||
|
import { applyInterceptWarmup } from '../credentialsBuilder'
|
||||||
|
|
||||||
|
describe('applyInterceptWarmup', () => {
|
||||||
|
it('create + enabled=true: should set intercept_warmup_requests to true', () => {
|
||||||
|
const creds: Record<string, unknown> = { access_token: 'tok' }
|
||||||
|
applyInterceptWarmup(creds, true, 'create')
|
||||||
|
expect(creds.intercept_warmup_requests).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('create + enabled=false: should not add the field', () => {
|
||||||
|
const creds: Record<string, unknown> = { access_token: 'tok' }
|
||||||
|
applyInterceptWarmup(creds, false, 'create')
|
||||||
|
expect('intercept_warmup_requests' in creds).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('edit + enabled=true: should set intercept_warmup_requests to true', () => {
|
||||||
|
const creds: Record<string, unknown> = { api_key: 'sk' }
|
||||||
|
applyInterceptWarmup(creds, true, 'edit')
|
||||||
|
expect(creds.intercept_warmup_requests).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('edit + enabled=false + field exists: should delete the field', () => {
|
||||||
|
const creds: Record<string, unknown> = { api_key: 'sk', intercept_warmup_requests: true }
|
||||||
|
applyInterceptWarmup(creds, false, 'edit')
|
||||||
|
expect('intercept_warmup_requests' in creds).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('edit + enabled=false + field absent: should not throw', () => {
|
||||||
|
const creds: Record<string, unknown> = { api_key: 'sk' }
|
||||||
|
applyInterceptWarmup(creds, false, 'edit')
|
||||||
|
expect('intercept_warmup_requests' in creds).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not affect other fields', () => {
|
||||||
|
const creds: Record<string, unknown> = {
|
||||||
|
api_key: 'sk',
|
||||||
|
base_url: 'url',
|
||||||
|
intercept_warmup_requests: true
|
||||||
|
}
|
||||||
|
applyInterceptWarmup(creds, false, 'edit')
|
||||||
|
expect(creds.api_key).toBe('sk')
|
||||||
|
expect(creds.base_url).toBe('url')
|
||||||
|
expect('intercept_warmup_requests' in creds).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
11
frontend/src/components/account/credentialsBuilder.ts
Normal file
11
frontend/src/components/account/credentialsBuilder.ts
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
export function applyInterceptWarmup(
|
||||||
|
credentials: Record<string, unknown>,
|
||||||
|
enabled: boolean,
|
||||||
|
mode: 'create' | 'edit'
|
||||||
|
): void {
|
||||||
|
if (enabled) {
|
||||||
|
credentials.intercept_warmup_requests = true
|
||||||
|
} else if (mode === 'edit') {
|
||||||
|
delete credentials.intercept_warmup_requests
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -76,6 +76,7 @@ const antigravityModels = [
|
|||||||
// Claude 4.5+ 系列
|
// Claude 4.5+ 系列
|
||||||
'claude-opus-4-6',
|
'claude-opus-4-6',
|
||||||
'claude-opus-4-5-thinking',
|
'claude-opus-4-5-thinking',
|
||||||
|
'claude-sonnet-4-6',
|
||||||
'claude-sonnet-4-5',
|
'claude-sonnet-4-5',
|
||||||
'claude-sonnet-4-5-thinking',
|
'claude-sonnet-4-5-thinking',
|
||||||
// Gemini 2.5 系列
|
// Gemini 2.5 系列
|
||||||
@@ -88,6 +89,9 @@ const antigravityModels = [
|
|||||||
'gemini-3-pro-high',
|
'gemini-3-pro-high',
|
||||||
'gemini-3-pro-low',
|
'gemini-3-pro-low',
|
||||||
'gemini-3-pro-image',
|
'gemini-3-pro-image',
|
||||||
|
// Gemini 3.1 系列
|
||||||
|
'gemini-3.1-pro-high',
|
||||||
|
'gemini-3.1-pro-low',
|
||||||
// 其他
|
// 其他
|
||||||
'gpt-oss-120b-medium',
|
'gpt-oss-120b-medium',
|
||||||
'tab_flash_lite_preview'
|
'tab_flash_lite_preview'
|
||||||
@@ -287,14 +291,23 @@ const antigravityPresetMappings = [
|
|||||||
{ label: 'Sonnet→Sonnet', from: 'claude-sonnet-*', to: 'claude-sonnet-4-5', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
{ label: 'Sonnet→Sonnet', from: 'claude-sonnet-*', to: 'claude-sonnet-4-5', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
||||||
{ label: 'Opus→Opus', from: 'claude-opus-*', to: 'claude-opus-4-6-thinking', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
{ label: 'Opus→Opus', from: 'claude-opus-*', to: 'claude-opus-4-6-thinking', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||||
{ label: 'Haiku→Sonnet', from: 'claude-haiku-*', to: 'claude-sonnet-4-5', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
{ label: 'Haiku→Sonnet', from: 'claude-haiku-*', to: 'claude-sonnet-4-5', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
||||||
|
{ label: 'Sonnet4→4.6', from: 'claude-sonnet-4-20250514', to: 'claude-sonnet-4-6', color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' },
|
||||||
|
{ label: 'Sonnet4.5→4.6', from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||||
|
{ label: 'Sonnet3.5→4.6', from: 'claude-3-5-sonnet-20241022', to: 'claude-sonnet-4-6', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' },
|
||||||
|
{ label: 'Opus4.5→4.6', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-6-thinking', color: 'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400' },
|
||||||
// Gemini 3→3.1 映射
|
// Gemini 3→3.1 映射
|
||||||
{ label: '3-Pro-Preview→3.1-Pro-High', from: 'gemini-3-pro-preview', to: 'gemini-3.1-pro-high', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' },
|
{ label: '3-Pro-Preview→3.1-Pro-High', from: 'gemini-3-pro-preview', to: 'gemini-3.1-pro-high', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' },
|
||||||
{ label: '3-Pro-High→3.1-Pro-High', from: 'gemini-3-pro-high', to: 'gemini-3.1-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
{ label: '3-Pro-High→3.1-Pro-High', from: 'gemini-3-pro-high', to: 'gemini-3.1-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
||||||
{ label: '3-Pro-Low→3.1-Pro-Low', from: 'gemini-3-pro-low', to: 'gemini-3.1-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
{ label: '3-Pro-Low→3.1-Pro-Low', from: 'gemini-3-pro-low', to: 'gemini-3.1-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
||||||
|
{ label: '3.1-Pro-High透传', from: 'gemini-3.1-pro-high', to: 'gemini-3.1-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
||||||
|
{ label: '3.1-Pro-Low透传', from: 'gemini-3.1-pro-low', to: 'gemini-3.1-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
||||||
// Gemini 通配符映射
|
// Gemini 通配符映射
|
||||||
{ label: 'Gemini 3→Flash', from: 'gemini-3*', to: 'gemini-3-flash', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
{ label: 'Gemini 3→Flash', from: 'gemini-3*', to: 'gemini-3-flash', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
||||||
{ label: 'Gemini 2.5→Flash', from: 'gemini-2.5*', to: 'gemini-2.5-flash', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
{ label: 'Gemini 2.5→Flash', from: 'gemini-2.5*', to: 'gemini-2.5-flash', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
||||||
|
{ label: '3-Flash透传', from: 'gemini-3-flash', to: 'gemini-3-flash', color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400' },
|
||||||
|
{ label: '2.5-Flash-Lite透传', from: 'gemini-2.5-flash-lite', to: 'gemini-2.5-flash-lite', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
||||||
// 精确映射
|
// 精确映射
|
||||||
|
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||||
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||||
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
|
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2047,7 +2047,7 @@ export default {
|
|||||||
gemini3Pro: 'G3P',
|
gemini3Pro: 'G3P',
|
||||||
gemini3Flash: 'G3F',
|
gemini3Flash: 'G3F',
|
||||||
gemini3Image: 'G3I',
|
gemini3Image: 'G3I',
|
||||||
claude45: 'C4.5'
|
claude: 'Claude'
|
||||||
},
|
},
|
||||||
tier: {
|
tier: {
|
||||||
free: 'Free',
|
free: 'Free',
|
||||||
|
|||||||
@@ -1583,7 +1583,7 @@ export default {
|
|||||||
gemini3Pro: 'G3P',
|
gemini3Pro: 'G3P',
|
||||||
gemini3Flash: 'G3F',
|
gemini3Flash: 'G3F',
|
||||||
gemini3Image: 'G3I',
|
gemini3Image: 'G3I',
|
||||||
claude45: 'C4.5'
|
claude: 'Claude'
|
||||||
},
|
},
|
||||||
tier: {
|
tier: {
|
||||||
free: 'Free',
|
free: 'Free',
|
||||||
|
|||||||
@@ -581,6 +581,7 @@ export interface GeminiCredentials {
|
|||||||
token_type?: string
|
token_type?: string
|
||||||
scope?: string
|
scope?: string
|
||||||
expires_at?: string
|
expires_at?: string
|
||||||
|
model_mapping?: Record<string, string>
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TempUnschedulableRule {
|
export interface TempUnschedulableRule {
|
||||||
@@ -766,6 +767,26 @@ export interface UpdateAccountRequest {
|
|||||||
confirm_mixed_channel_risk?: boolean
|
confirm_mixed_channel_risk?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface CheckMixedChannelRequest {
|
||||||
|
platform: AccountPlatform
|
||||||
|
group_ids: number[]
|
||||||
|
account_id?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MixedChannelWarningDetails {
|
||||||
|
group_id: number
|
||||||
|
group_name: string
|
||||||
|
current_platform: string
|
||||||
|
other_platform: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CheckMixedChannelResponse {
|
||||||
|
has_risk: boolean
|
||||||
|
error?: string
|
||||||
|
message?: string
|
||||||
|
details?: MixedChannelWarningDetails
|
||||||
|
}
|
||||||
|
|
||||||
export interface CreateProxyRequest {
|
export interface CreateProxyRequest {
|
||||||
name: string
|
name: string
|
||||||
protocol: ProxyProtocol
|
protocol: ProxyProtocol
|
||||||
|
|||||||
@@ -1,18 +1,13 @@
|
|||||||
import { defineConfig } from 'vitest/config'
|
import { defineConfig } from 'vitest/config'
|
||||||
import vue from '@vitejs/plugin-vue'
|
|
||||||
import { resolve } from 'path'
|
import { resolve } from 'path'
|
||||||
|
|
||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
plugins: [vue()],
|
|
||||||
resolve: {
|
resolve: {
|
||||||
alias: {
|
alias: {
|
||||||
'@': resolve(__dirname, 'src'),
|
'@': resolve(__dirname, 'src'),
|
||||||
'vue-i18n': 'vue-i18n/dist/vue-i18n.runtime.esm-bundler.js'
|
'vue-i18n': 'vue-i18n/dist/vue-i18n.runtime.esm-bundler.js'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
define: {
|
|
||||||
__INTLIFY_JIT_COMPILATION__: true
|
|
||||||
},
|
|
||||||
test: {
|
test: {
|
||||||
globals: true,
|
globals: true,
|
||||||
environment: 'jsdom',
|
environment: 'jsdom',
|
||||||
@@ -37,8 +32,6 @@ export default defineConfig({
|
|||||||
lines: 80
|
lines: 80
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
setupFiles: ['./src/__tests__/setup.ts'],
|
|
||||||
testTimeout: 10000
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user