mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-07 00:40:22 +08:00
Compare commits
87 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bae525026 | ||
|
|
df00805a2a | ||
|
|
a88ee96518 | ||
|
|
3cc2f9bd57 | ||
|
|
d1b684b782 | ||
|
|
6460d4ad3a | ||
|
|
19ea392d5d | ||
|
|
fb4d016176 | ||
|
|
afec747d9e | ||
|
|
7388fcce41 | ||
|
|
a6f9f9f968 | ||
|
|
29759721e0 | ||
|
|
1941b20521 | ||
|
|
e6969acb50 | ||
|
|
9489531431 | ||
|
|
32b7c0ca9b | ||
|
|
4ac57b4edf | ||
|
|
685a1e0ba3 | ||
|
|
e350aab1bd | ||
|
|
0dd6986e28 | ||
|
|
6d0102a70c | ||
|
|
f96a2a18c1 | ||
|
|
f955b04a6f | ||
|
|
2fd6ac319b | ||
|
|
82fbf452a8 | ||
|
|
ba69736f55 | ||
|
|
c75c6b6858 | ||
|
|
de61745bb2 | ||
|
|
3fab0fcd4c | ||
|
|
03bcd94ae5 | ||
|
|
0343bc7777 | ||
|
|
565d19acfd | ||
|
|
960acf1982 | ||
|
|
ece911521e | ||
|
|
5d95e59742 | ||
|
|
01d084bbfd | ||
|
|
7918fc2844 | ||
|
|
31b30a6df2 | ||
|
|
d217b59e0b | ||
|
|
169a4b9d32 | ||
|
|
15f3ffb165 | ||
|
|
02db1010dd | ||
|
|
935ea66681 | ||
|
|
26060e702f | ||
|
|
65d4ca2563 | ||
|
|
3c619a8da5 | ||
|
|
ded9b6c14e | ||
|
|
609abbbd7c | ||
|
|
1b4e504fad | ||
|
|
0a3a445828 | ||
|
|
c7e18bd5be | ||
|
|
083d202fe4 | ||
|
|
8365a8328b | ||
|
|
58f21e4b3a | ||
|
|
5bd7408b2f | ||
|
|
c671e8dd1d | ||
|
|
a3aed3c4c3 | ||
|
|
c008649584 | ||
|
|
516f8f287c | ||
|
|
66148690c6 | ||
|
|
cadd7f546f | ||
|
|
a3ff317f1c | ||
|
|
d8d4b0c0c7 | ||
|
|
d616f8c854 | ||
|
|
b6fa8b8eec | ||
|
|
36d2e6999b | ||
|
|
076c00063d | ||
|
|
ea8104c6a2 | ||
|
|
ca3e9336e1 | ||
|
|
f92ab48166 | ||
|
|
c10267ce2b | ||
|
|
9bd6a62ab3 | ||
|
|
0dbea6ca58 | ||
|
|
6523b23221 | ||
|
|
29c406dda0 | ||
|
|
483c8f246d | ||
|
|
645f283108 | ||
|
|
da6fd45000 | ||
|
|
fb3ef5f388 | ||
|
|
86bc76e352 | ||
|
|
644058174e | ||
|
|
4573868c08 | ||
|
|
09166a52f8 | ||
|
|
aaac1aaca9 | ||
|
|
59898c16c6 | ||
|
|
0dacdf480b | ||
|
|
fdf9f68298 |
25
DEV_GUIDE.md
25
DEV_GUIDE.md
@@ -209,7 +209,30 @@ git add ent/ # 生成的文件也要提交
|
||||
|
||||
---
|
||||
|
||||
### 坑 10:PR 提交前检查清单
|
||||
### 坑 10:前端测试看似正常,但后端调用失败(模型映射被批量误改)
|
||||
|
||||
**典型现象**:
|
||||
- 前端按钮点测看起来正常;
|
||||
- 实际通过 API/客户端调用时返回 `Service temporarily unavailable` 或提示无可用账号;
|
||||
- 常见于 OpenAI 账号(例如 Codex 模型)在批量修改后突然不可用。
|
||||
|
||||
**根因**:
|
||||
- OpenAI 账号编辑页默认不显式展示映射规则,容易让人误以为“没映射也没关系”;
|
||||
- 但在**批量修改同时选中不同平台账号**(OpenAI + Antigravity/Gemini)时,模型白名单/映射可能被跨平台策略覆盖;
|
||||
- 结果是 OpenAI 账号的关键模型映射丢失或被改坏,后端选不到可用账号。
|
||||
|
||||
**修复方案(按优先级)**:
|
||||
1. **快速修复(推荐)**:在批量修改中补回正确的透传映射(例如 `gpt-5.3-codex -> gpt-5.3-codex-spark`)。
|
||||
2. **彻底重建**:删除并重新添加全部相关账号(最稳但成本高)。
|
||||
|
||||
**关键经验**:
|
||||
- 如果某模型已被软件内置默认映射覆盖,通常不需要额外再加透传;
|
||||
- 但当上游模型更新快于本仓库默认映射时,**手动批量添加透传映射**是最简单、最低风险的临时兜底方案;
|
||||
- 批量操作前尽量按平台分组,不要混选不同平台账号。
|
||||
|
||||
---
|
||||
|
||||
### 坑 11:PR 提交前检查清单
|
||||
|
||||
提交 PR 前务必本地验证:
|
||||
|
||||
|
||||
@@ -113,7 +113,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
|
||||
driveClient := repository.NewGeminiDriveClient()
|
||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
|
||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
@@ -187,9 +188,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
|
||||
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, soraMediaStorage, rateLimitService, configConfig)
|
||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
|
||||
@@ -5,6 +5,7 @@ go 1.25.7
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/dgraph-io/ristretto v0.2.0
|
||||
@@ -18,7 +19,7 @@ require (
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/refraction-networking/utls v1.8.1
|
||||
github.com/refraction-networking/utls v1.8.2
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/shirou/gopsutil/v4 v4.25.6
|
||||
github.com/spf13/viper v1.18.2
|
||||
@@ -29,10 +30,10 @@ require (
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/zeromicro/go-zero v1.9.4
|
||||
go.uber.org/zap v1.24.0
|
||||
golang.org/x/crypto v0.47.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/net v0.49.0
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.39.0
|
||||
golang.org/x/term v0.40.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.44.3
|
||||
@@ -46,7 +47,14 @@ require (
|
||||
github.com/agext/levenshtein v1.2.3 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
github.com/bogdanfinn/fhttp v0.6.8 // indirect
|
||||
github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect
|
||||
github.com/bogdanfinn/tls-client v1.14.0 // indirect
|
||||
github.com/bogdanfinn/utls v1.7.7-barnius // indirect
|
||||
github.com/bogdanfinn/websocket v1.5.5-barnius // indirect
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
@@ -79,7 +87,6 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
@@ -124,6 +131,7 @@ require (
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
@@ -145,10 +153,9 @@ require (
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
golang.org/x/mod v0.32.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
|
||||
@@ -10,6 +10,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU=
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
|
||||
@@ -20,10 +22,24 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||
github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI=
|
||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
|
||||
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
|
||||
github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4=
|
||||
github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M=
|
||||
github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s=
|
||||
github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg=
|
||||
github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A=
|
||||
github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM=
|
||||
github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU=
|
||||
github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg=
|
||||
github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI=
|
||||
github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
@@ -120,8 +136,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -176,8 +190,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@@ -211,8 +223,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -238,12 +248,10 @@ github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI1
|
||||
github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@@ -266,8 +274,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
@@ -289,6 +295,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
||||
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
|
||||
@@ -355,18 +363,21 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||
golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -374,16 +385,19 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
|
||||
@@ -1088,9 +1088,9 @@ func setDefaults() {
|
||||
// RateLimit
|
||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||
|
||||
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
|
||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256")
|
||||
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置
|
||||
viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256")
|
||||
viper.SetDefault("pricing.data_dir", "./data")
|
||||
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
|
||||
viper.SetDefault("pricing.update_interval_hours", 24)
|
||||
@@ -1158,6 +1158,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.force_codex_cli", false)
|
||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||
|
||||
@@ -74,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
|
||||
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
|
||||
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
// Claude 详细版本 ID 映射
|
||||
@@ -88,17 +89,21 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
// Gemini 3 白名单
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3.1-pro-high",
|
||||
"gemini-3-pro-low": "gemini-3.1-pro-low",
|
||||
"gemini-3-pro-image": "gemini-3-pro-image",
|
||||
// Gemini 3.1 透传
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||
// Gemini 3 preview 映射
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
// Gemini 3.1 白名单
|
||||
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
|
||||
// Gemini 3 preview 映射
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-3-pro-preview": "gemini-3.1-pro-high",
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
// Gemini 3.1 preview 映射
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
|
||||
// Gemini 3.1 image 白名单
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
// Gemini 3.1 image preview 映射
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
// 其他官方模型
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -139,6 +140,13 @@ type BulkUpdateAccountsRequest struct {
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
// CheckMixedChannelRequest represents check mixed channel risk request
|
||||
type CheckMixedChannelRequest struct {
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
}
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
type AccountWithConcurrency struct {
|
||||
*dto.Account
|
||||
@@ -389,6 +397,50 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
|
||||
// POST /api/v1/admin/accounts/check-mixed-channel
|
||||
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
|
||||
var req CheckMixedChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.GroupIDs) == 0 {
|
||||
response.Success(c, gin.H{"has_risk": false})
|
||||
return
|
||||
}
|
||||
|
||||
accountID := int64(0)
|
||||
if req.AccountID != nil {
|
||||
accountID = *req.AccountID
|
||||
}
|
||||
|
||||
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
|
||||
if err != nil {
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
response.Success(c, gin.H{
|
||||
"has_risk": true,
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"has_risk": false})
|
||||
}
|
||||
|
||||
// Create handles creating a new account
|
||||
// POST /api/v1/admin/accounts
|
||||
func (h *AccountHandler) Create(c *gin.Context) {
|
||||
@@ -431,17 +483,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
// 检查是否为混合渠道错误
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
// 返回特殊错误码要求确认
|
||||
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
"require_confirmation": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -501,17 +546,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
// 检查是否为混合渠道错误
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
// 返回特殊错误码要求确认
|
||||
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
"require_confirmation": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -1422,32 +1460,8 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle Antigravity accounts: return Claude + Gemini models
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
// Antigravity 支持 Claude 和部分 Gemini 模型
|
||||
type UnifiedModel struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
var models []UnifiedModel
|
||||
|
||||
// 添加 Claude 模型
|
||||
for _, m := range claude.DefaultModels {
|
||||
models = append(models, UnifiedModel{
|
||||
ID: m.ID,
|
||||
Type: m.Type,
|
||||
DisplayName: m.DisplayName,
|
||||
})
|
||||
}
|
||||
|
||||
// 添加 Gemini 3 系列模型用于测试
|
||||
geminiTestModels := []UnifiedModel{
|
||||
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"},
|
||||
}
|
||||
models = append(models, geminiTestModels...)
|
||||
|
||||
response.Success(c, models)
|
||||
// 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步
|
||||
response.Success(c, antigravity.DefaultModels())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
|
||||
router.POST("/api/v1/admin/accounts", accountHandler.Create)
|
||||
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"platform": "antigravity",
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, false, data["has_risk"])
|
||||
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
|
||||
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
|
||||
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
|
||||
}
|
||||
|
||||
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.checkMixedErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"platform": "antigravity",
|
||||
"group_ids": []int64{27},
|
||||
"account_id": 99,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["has_risk"])
|
||||
require.Equal(t, "mixed_channel_warning", data["error"])
|
||||
details, ok := data["details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, float64(27), details["group_id"])
|
||||
require.Equal(t, "claude-max", details["group_name"])
|
||||
require.Equal(t, "Antigravity", details["current_platform"])
|
||||
require.Equal(t, "Anthropic", details["other_platform"])
|
||||
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.createAccountErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"name": "ag-oauth-1",
|
||||
"platform": "antigravity",
|
||||
"type": "oauth",
|
||||
"credentials": map[string]any{"refresh_token": "rt"},
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
require.False(t, hasRequireConfirmation)
|
||||
}
|
||||
|
||||
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.updateAccountErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
require.False(t, hasRequireConfirmation)
|
||||
}
|
||||
@@ -10,19 +10,27 @@ import (
|
||||
)
|
||||
|
||||
type stubAdminService struct {
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
createdAccounts []*service.CreateAccountInput
|
||||
createdProxies []*service.CreateProxyInput
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
mu sync.Mutex
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
createdAccounts []*service.CreateAccountInput
|
||||
createdProxies []*service.CreateProxyInput
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
createAccountErr error
|
||||
updateAccountErr error
|
||||
checkMixedErr error
|
||||
lastMixedCheck struct {
|
||||
accountID int64
|
||||
platform string
|
||||
groupIDs []int64
|
||||
}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStubAdminService() *stubAdminService {
|
||||
@@ -188,11 +196,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
|
||||
s.mu.Lock()
|
||||
s.createdAccounts = append(s.createdAccounts, input)
|
||||
s.mu.Unlock()
|
||||
if s.createAccountErr != nil {
|
||||
return nil, s.createAccountErr
|
||||
}
|
||||
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
if s.updateAccountErr != nil {
|
||||
return nil, s.updateAccountErr
|
||||
}
|
||||
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
@@ -224,6 +238,13 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
|
||||
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
s.lastMixedCheck.accountID = currentAccountID
|
||||
s.lastMixedCheck.platform = currentAccountPlatform
|
||||
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
|
||||
return s.checkMixedErr
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
||||
search = strings.TrimSpace(strings.ToLower(search))
|
||||
filtered := make([]service.Proxy, 0, len(s.proxies))
|
||||
|
||||
@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
// Treat missing/invalid OAuth client configuration as a user/config error.
|
||||
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
|
||||
if strings.Contains(msg, "OAuth client not configured") ||
|
||||
strings.Contains(msg, "requires your own OAuth Client") ||
|
||||
strings.Contains(msg, "requires a custom OAuth Client") ||
|
||||
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
|
||||
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
|
||||
response.BadRequest(c, "Failed to generate auth URL: "+msg)
|
||||
return
|
||||
}
|
||||
|
||||
160
backend/internal/handler/failover_loop.go
Normal file
160
backend/internal/handler/failover_loop.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
|
||||
// GatewayService 隐式实现此接口。
|
||||
type TempUnscheduler interface {
|
||||
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
|
||||
}
|
||||
|
||||
// FailoverAction 表示 failover 错误处理后的下一步动作
|
||||
type FailoverAction int
|
||||
|
||||
const (
|
||||
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue)
|
||||
FailoverContinue FailoverAction = iota
|
||||
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
|
||||
FailoverExhausted
|
||||
// FailoverCanceled context 已取消(调用方应直接 return)
|
||||
FailoverCanceled
|
||||
)
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s),
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
singleAccountBackoffDelay = 2 * time.Second
|
||||
)
|
||||
|
||||
// FailoverState 跨循环迭代共享的 failover 状态
|
||||
type FailoverState struct {
|
||||
SwitchCount int
|
||||
MaxSwitches int
|
||||
FailedAccountIDs map[int64]struct{}
|
||||
SameAccountRetryCount map[int64]int
|
||||
LastFailoverErr *service.UpstreamFailoverError
|
||||
ForceCacheBilling bool
|
||||
hasBoundSession bool
|
||||
}
|
||||
|
||||
// NewFailoverState 创建 failover 状态
|
||||
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
|
||||
return &FailoverState{
|
||||
MaxSwitches: maxSwitches,
|
||||
FailedAccountIDs: make(map[int64]struct{}),
|
||||
SameAccountRetryCount: make(map[int64]int),
|
||||
hasBoundSession: hasBoundSession,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。
|
||||
// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
|
||||
func (s *FailoverState) HandleFailoverError(
|
||||
ctx context.Context,
|
||||
gatewayService TempUnscheduler,
|
||||
accountID int64,
|
||||
platform string,
|
||||
failoverErr *service.UpstreamFailoverError,
|
||||
) FailoverAction {
|
||||
s.LastFailoverErr = failoverErr
|
||||
|
||||
// 缓存计费判断
|
||||
if needForceCacheBilling(s.hasBoundSession, failoverErr) {
|
||||
s.ForceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
|
||||
s.SameAccountRetryCount[accountID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
|
||||
if !sleepWithContext(ctx, sameAccountRetryDelay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
return FailoverContinue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
|
||||
}
|
||||
|
||||
// 加入失败列表
|
||||
s.FailedAccountIDs[accountID] = struct{}{}
|
||||
|
||||
// 检查是否耗尽
|
||||
if s.SwitchCount >= s.MaxSwitches {
|
||||
return FailoverExhausted
|
||||
}
|
||||
|
||||
// 递增切换计数
|
||||
s.SwitchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d",
|
||||
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
|
||||
|
||||
// Antigravity 平台换号线性递增延时
|
||||
if platform == service.PlatformAntigravity {
|
||||
delay := time.Duration(s.SwitchCount-1) * time.Second
|
||||
if !sleepWithContext(ctx, delay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
}
|
||||
|
||||
return FailoverContinue
|
||||
}
|
||||
|
||||
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
|
||||
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
|
||||
// 清除排除列表、等待退避后重新选号。
|
||||
//
|
||||
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
|
||||
// 返回 FailoverExhausted 时,调用方应返回错误响应。
|
||||
// 返回 FailoverCanceled 时,调用方应直接 return。
|
||||
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
|
||||
if s.LastFailoverErr != nil &&
|
||||
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
|
||||
s.SwitchCount <= s.MaxSwitches {
|
||||
|
||||
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
|
||||
singleAccountBackoffDelay, s.SwitchCount)
|
||||
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
|
||||
s.SwitchCount, s.MaxSwitches)
|
||||
s.FailedAccountIDs = make(map[int64]struct{})
|
||||
return FailoverContinue
|
||||
}
|
||||
return FailoverExhausted
|
||||
}
|
||||
|
||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
|
||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
|
||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||
}
|
||||
|
||||
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) bool {
|
||||
if d <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(d):
|
||||
return true
|
||||
}
|
||||
}
|
||||
732
backend/internal/handler/failover_loop_test.go
Normal file
732
backend/internal/handler/failover_loop_test.go
Normal file
@@ -0,0 +1,732 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
|
||||
type mockTempUnscheduler struct {
|
||||
calls []tempUnscheduleCall
|
||||
}
|
||||
|
||||
type tempUnscheduleCall struct {
|
||||
accountID int64
|
||||
failoverErr *service.UpstreamFailoverError
|
||||
}
|
||||
|
||||
func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) {
|
||||
m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError {
|
||||
return &service.UpstreamFailoverError{
|
||||
StatusCode: statusCode,
|
||||
RetryableOnSameAccount: retryable,
|
||||
ForceCacheBilling: forceBilling,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NewFailoverState 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewFailoverState(t *testing.T) {
|
||||
t.Run("初始化字段正确", func(t *testing.T) {
|
||||
fs := NewFailoverState(5, true)
|
||||
require.Equal(t, 5, fs.MaxSwitches)
|
||||
require.Equal(t, 0, fs.SwitchCount)
|
||||
require.NotNil(t, fs.FailedAccountIDs)
|
||||
require.Empty(t, fs.FailedAccountIDs)
|
||||
require.NotNil(t, fs.SameAccountRetryCount)
|
||||
require.Empty(t, fs.SameAccountRetryCount)
|
||||
require.Nil(t, fs.LastFailoverErr)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
require.True(t, fs.hasBoundSession)
|
||||
})
|
||||
|
||||
t.Run("无绑定会话", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
require.Equal(t, 3, fs.MaxSwitches)
|
||||
require.False(t, fs.hasBoundSession)
|
||||
})
|
||||
|
||||
t.Run("零最大切换次数", func(t *testing.T) {
|
||||
fs := NewFailoverState(0, false)
|
||||
require.Equal(t, 0, fs.MaxSwitches)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// sleepWithContext 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSleepWithContext(t *testing.T) {
|
||||
t.Run("零时长立即返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), 0)
|
||||
require.True(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("负时长立即返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), -1*time.Second)
|
||||
require.True(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("正常等待后返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), 50*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
require.True(t, ok)
|
||||
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond)
|
||||
require.Less(t, elapsed, 500*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("已取消context立即返回false", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(ctx, 5*time.Second)
|
||||
require.False(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("等待期间context取消返回false", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(ctx, 5*time.Second)
|
||||
elapsed := time.Since(start)
|
||||
require.False(t, ok)
|
||||
require.Less(t, elapsed, 500*time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 基本切换流程
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_BasicSwitch(t *testing.T) {
|
||||
t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
require.Equal(t, err, fs.LastFailoverErr)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
require.Empty(t, mock.calls, "不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) {
|
||||
// switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0")
|
||||
})
|
||||
|
||||
t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) {
|
||||
// switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1 // 模拟已切换一次
|
||||
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s")
|
||||
require.Less(t, elapsed, 3*time.Second)
|
||||
})
|
||||
|
||||
t.Run("连续切换直到耗尽", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(2, false)
|
||||
|
||||
// 第一次切换:0→1
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
|
||||
// 第二次切换:1→2
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
|
||||
// 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2)
|
||||
err3 := newTestFailoverErr(503, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增")
|
||||
|
||||
// 验证失败账号列表
|
||||
require.Len(t, fs.FailedAccountIDs, 3)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(300))
|
||||
|
||||
// LastFailoverErr 应为最后一次的错误
|
||||
require.Equal(t, err3, fs.LastFailoverErr)
|
||||
})
|
||||
|
||||
t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(0, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Equal(t, 0, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_CacheBilling(t *testing.T) {
|
||||
t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("两者均为false时不设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("一旦设置不会被后续错误重置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
// 第一次:ForceCacheBilling=true → 设置
|
||||
err1 := newTestFailoverErr(500, false, true)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
|
||||
// 第二次:ForceCacheBilling=false → 仍然保持 true
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
t.Run("第一次重试返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数")
|
||||
require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表")
|
||||
require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule")
|
||||
// 验证等待了 sameAccountRetryDelay (500ms)
|
||||
require.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
|
||||
// 验证 TempUnschedule 被调用
|
||||
require.Len(t, mock.calls, 1)
|
||||
require.Equal(t, int64(100), mock.calls[0].accountID)
|
||||
require.Equal(t, err, mock.calls[0].failoverErr)
|
||||
})
|
||||
|
||||
t.Run("不同账号独立跟踪重试次数", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 账号 100 第一次重试
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 账号 200 第一次重试(独立计数)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[200])
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响")
|
||||
})
|
||||
|
||||
t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — TempUnschedule 调用验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Empty(t, mock.calls)
|
||||
})
|
||||
|
||||
t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
require.Equal(t, int64(42), mock.calls[0].accountID)
|
||||
require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode)
|
||||
require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — Context 取消
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_ContextCanceled(t *testing.T) {
|
||||
t.Run("同账号重试sleep期间context取消", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(ctx, mock, 100, "openai", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||
// 重试计数仍应递增
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
})
|
||||
|
||||
t.Run("Antigravity延迟期间context取消", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — FailedAccountIDs 跟踪
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_FailedAccountIDs(t *testing.T) {
|
||||
t.Run("切换时添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||
require.Len(t, fs.FailedAccountIDs, 2)
|
||||
})
|
||||
|
||||
t.Run("耗尽时也添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(0, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
|
||||
t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false))
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.NotContains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
|
||||
t.Run("同一账号多次切换不重复添加", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — LastFailoverErr 更新
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_LastFailoverErr(t *testing.T) {
|
||||
t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.Equal(t, err1, fs.LastFailoverErr)
|
||||
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.Equal(t, err2, fs.LastFailoverErr)
|
||||
})
|
||||
|
||||
t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, err, fs.LastFailoverErr)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 综合集成场景
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
// 3. 账号 200 遇到不可重试错误 → 直接切换
|
||||
switchErr := newTestFailoverErr(500, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
|
||||
// 4. 账号 300 遇到不可重试错误 → 再切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 3, fs.SwitchCount)
|
||||
|
||||
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
|
||||
// 最终状态验证
|
||||
require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增")
|
||||
require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中")
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("模拟Antigravity平台完整流程", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(2, false)
|
||||
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
// 第一次切换:delay = 0s
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0")
|
||||
|
||||
// 第二次切换:delay = 1s
|
||||
start = time.Now()
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||
elapsed = time.Since(start)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s")
|
||||
|
||||
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
|
||||
start = time.Now()
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err)
|
||||
elapsed = time.Since(start)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟")
|
||||
})
|
||||
|
||||
t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false) // hasBoundSession=false
|
||||
|
||||
// 第一次:ForceCacheBilling=false
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
|
||||
// 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换)
|
||||
err2 := newTestFailoverErr(500, false, true)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling")
|
||||
|
||||
// 第三次:ForceCacheBilling=false,但状态仍保持 true
|
||||
err3 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||
require.True(t, fs.ForceCacheBilling, "不应重置")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 边界条件
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_EdgeCases(t *testing.T) {
|
||||
t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(0, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
})
|
||||
|
||||
t.Run("AccountID为0也能正常跟踪", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, true, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[0])
|
||||
})
|
||||
|
||||
t.Run("负AccountID也能正常跟踪", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, true, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[-1])
|
||||
})
|
||||
|
||||
t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleSelectionExhausted 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleSelectionExhausted(t *testing.T) {
|
||||
t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
// LastFailoverErr 为 nil
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
})
|
||||
|
||||
t.Run("非503错误返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
})
|
||||
|
||||
t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.FailedAccountIDs[100] = struct{}{}
|
||||
fs.SwitchCount = 1
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表")
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s")
|
||||
require.Less(t, elapsed, 5*time.Second)
|
||||
})
|
||||
|
||||
t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(2, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.SwitchCount = 3 // > MaxSwitches(2)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "不应等待")
|
||||
})
|
||||
|
||||
t.Run("503但context已取消_返回Canceled", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(ctx)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||
})
|
||||
|
||||
t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) {
|
||||
fs := NewFailoverState(2, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
})
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
@@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
reqLog.Warn("gateway.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
@@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
lastFailoverErr = failoverErr
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁并切换账号
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||
}
|
||||
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||
case FailoverExhausted:
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("gateway.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
|
||||
retryWithFallback := false
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
reqLog.Warn("gateway.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
@@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
lastFailoverErr = failoverErr
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁并切换账号
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||
}
|
||||
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||
case FailoverExhausted:
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("gateway.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -733,11 +650,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
).Error("gateway.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("gateway.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Bool("fallback_used", fallbackUsed),
|
||||
)
|
||||
return
|
||||
}
|
||||
if !retryWithFallback {
|
||||
@@ -982,69 +894,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
|
||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
|
||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||
}
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
|
||||
func sleepSameAccountRetryDelay(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
||||
delay := time.Duration(switchCount-1) * time.Second
|
||||
if delay <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
|
||||
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
|
||||
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
|
||||
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
|
||||
// 固定短延时:2s
|
||||
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
const delay = 2 * time.Second
|
||||
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.failover"),
|
||||
zap.Duration("delay", delay),
|
||||
zap.Int("retry_count", retryCount),
|
||||
).Info("gateway.single_account_backoff_waiting")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// sleepAntigravitySingleAccountBackoff 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.True(t, ok, "should return true when context is not canceled")
|
||||
// 固定延迟 2s
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
|
||||
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
|
||||
}
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.False(t, ok, "should return false when context is canceled")
|
||||
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
|
||||
}
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
|
||||
// 验证不同 retryCount 都使用固定 2s 延迟
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.True(t, ok)
|
||||
// 即使 retryCount=5,延迟仍然是固定的 2s
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
|
||||
require.Less(t, elapsed, 5*time.Second)
|
||||
}
|
||||
@@ -0,0 +1,340 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
|
||||
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
|
||||
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
|
||||
|
||||
type fakeSchedulerCache struct {
|
||||
accounts []*service.Account
|
||||
}
|
||||
|
||||
func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return f.accounts, true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
|
||||
func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil }
|
||||
func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil }
|
||||
func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil }
|
||||
|
||||
type fakeGroupRepo struct {
|
||||
group *service.Group
|
||||
}
|
||||
|
||||
func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil }
|
||||
func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) {
|
||||
return f.group, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
|
||||
return f.group, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil }
|
||||
func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil }
|
||||
func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil }
|
||||
func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil }
|
||||
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil }
|
||||
func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeConcurrencyCache struct{}
|
||||
|
||||
func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil }
|
||||
func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
|
||||
schedulerCache := &fakeSchedulerCache{accounts: accounts}
|
||||
schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil)
|
||||
|
||||
gwSvc := service.NewGatewayService(
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
nil, // cache (disable sticky)
|
||||
nil, // cfg
|
||||
schedulerSnapshot,
|
||||
nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
|
||||
nil, // billingService
|
||||
nil, // rateLimitService
|
||||
nil, // billingCacheService
|
||||
nil, // identityService
|
||||
nil, // httpUpstream
|
||||
nil, // deferredService
|
||||
nil, // claudeTokenProvider
|
||||
nil, // sessionLimitCache
|
||||
nil, // digestStore
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
h := &GatewayHandler{
|
||||
gatewayService: gwSvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
concurrencyHelper: concurrencyHelper,
|
||||
// 这些字段对本测试不敏感,保持较小即可
|
||||
maxAccountSwitches: 1,
|
||||
maxAccountSwitchesGemini: 1,
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
billingCacheSvc.Stop()
|
||||
}
|
||||
return h, cleanup
|
||||
}
|
||||
|
||||
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(2001)
|
||||
accountID := int64(1001)
|
||||
|
||||
group := &service.Group{
|
||||
ID: groupID,
|
||||
Hydrated: true,
|
||||
Platform: service.PlatformAnthropic, // /v1/messages(Claude兼容)入口
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: accountID,
|
||||
Name: "ag-1",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "tok_xxx",
|
||||
"intercept_warmup_requests": true,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中
|
||||
},
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||
}
|
||||
|
||||
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||
defer cleanup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||
}`)
|
||||
req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group))
|
||||
c.Request = req
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 3001,
|
||||
UserID: 4001,
|
||||
GroupID: &groupID,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 4001,
|
||||
Concurrency: 10,
|
||||
Balance: 100,
|
||||
},
|
||||
Group: group,
|
||||
}
|
||||
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||
|
||||
h.Messages(c)
|
||||
|
||||
require.Equal(t, 200, rec.Code)
|
||||
|
||||
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
|
||||
selected, ok := c.Get(opsAccountIDKey)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, accountID, selected)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||
|
||||
content, ok := resp["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
first, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "New Conversation", first["text"])
|
||||
}
|
||||
|
||||
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(2002)
|
||||
accountID := int64(1002)
|
||||
|
||||
group := &service.Group{
|
||||
ID: groupID,
|
||||
Hydrated: true,
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: accountID,
|
||||
Name: "ag-2",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "tok_xxx",
|
||||
"intercept_warmup_requests": true,
|
||||
},
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||
}
|
||||
|
||||
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||
defer cleanup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||
}`)
|
||||
req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
|
||||
// - 写入 request.Context(Service读取)
|
||||
// - 写入 gin.Context(Handler快速读取)
|
||||
ctx := context.WithValue(req.Context(), ctxkey.Group, group)
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity)
|
||||
req = req.WithContext(ctx)
|
||||
c.Request = req
|
||||
c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 3002,
|
||||
UserID: 4002,
|
||||
GroupID: &groupID,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 4002,
|
||||
Concurrency: 10,
|
||||
Balance: 100,
|
||||
},
|
||||
Group: group,
|
||||
}
|
||||
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||
|
||||
h.Messages(c)
|
||||
|
||||
require.Equal(t, 200, rec.Code)
|
||||
|
||||
selected, ok := c.Get(opsAccountIDKey)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, accountID, selected)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||
}
|
||||
@@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
@@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
reqLog.Warn("gemini.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
}
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||
return
|
||||
}
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 5) forward (根据平台分流)
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||
@@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverErr = failoverErr
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch failoverAction {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverExhausted:
|
||||
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
reqLog.Warn("gemini.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
@@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
IPAddress: clientIP,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
})
|
||||
reqLog.Debug("gemini.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("switch_count", fs.SwitchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -151,6 +151,8 @@ var claudeModels = []modelDef{
|
||||
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
@@ -161,6 +163,8 @@ var geminiModels = []modelDef{
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
@@ -400,7 +400,9 @@ func TestShouldFallbackToNextURL_无错误且200(t *testing.T) {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClient_ExchangeCode_成功(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 验证请求方法
|
||||
@@ -493,7 +495,9 @@ func TestClient_ExchangeCode_成功(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
client := NewClient("")
|
||||
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
|
||||
@@ -506,7 +510,9 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -531,7 +537,9 @@ func TestClient_ExchangeCode_服务器返回错误(t *testing.T) {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClient_RefreshToken_MockServer(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -590,7 +598,9 @@ func TestClient_RefreshToken_MockServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
client := NewClient("")
|
||||
_, err := client.RefreshToken(context.Background(), "refresh-tok")
|
||||
@@ -784,7 +794,9 @@ func newTestClientWithRedirect(redirects map[string]string) *Client {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -853,7 +865,9 @@ func TestClient_ExchangeCode_Success_RealCall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -878,7 +892,9 @@ func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -901,7 +917,9 @@ func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(5 * time.Second) // 模拟慢响应
|
||||
@@ -927,7 +945,9 @@ func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -976,7 +996,9 @@ func TestClient_RefreshToken_Success_RealCall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
@@ -998,7 +1020,9 @@ func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -1021,7 +1045,9 @@ func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = "test-secret"
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
@@ -23,11 +23,9 @@ const (
|
||||
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
// Antigravity OAuth 客户端凭证
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = ""
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
|
||||
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
||||
// 出于安全原因,该值不得硬编码入库。
|
||||
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
||||
|
||||
// 固定的 redirect_uri(用户需手动复制 code)
|
||||
@@ -51,14 +49,21 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.84.2
|
||||
var defaultUserAgentVersion = "1.84.2"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
|
||||
defaultUserAgentVersion = version
|
||||
}
|
||||
// 从环境变量读取 client_secret,未设置则使用默认值
|
||||
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
|
||||
defaultClientSecret = secret
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAgent 返回当前配置的 User-Agent
|
||||
@@ -67,14 +72,9 @@ func GetUserAgent() string {
|
||||
}
|
||||
|
||||
func getClientSecret() (string, error) {
|
||||
if v := strings.TrimSpace(ClientSecret); v != "" {
|
||||
if v := strings.TrimSpace(defaultClientSecret); v != "" {
|
||||
return v, nil
|
||||
}
|
||||
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
|
||||
if vv := strings.TrimSpace(v); vv != "" {
|
||||
return vv, nil
|
||||
}
|
||||
}
|
||||
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -17,8 +18,14 @@ import (
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetClientSecret_环境变量设置(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
|
||||
|
||||
// 需要重新触发 init 逻辑:手动从环境变量读取
|
||||
defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv)
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
@@ -29,11 +36,13 @@ func TestGetClientSecret_环境变量设置(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量为空(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量为空时应返回错误")
|
||||
t.Fatal("defaultClientSecret 为空时应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
|
||||
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
|
||||
@@ -41,30 +50,31 @@ func TestGetClientSecret_环境变量为空(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量未设置(t *testing.T) {
|
||||
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
|
||||
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
|
||||
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
|
||||
|
||||
// 明确设置再取消,确保环境变量不存在
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量未设置时应返回错误")
|
||||
t.Fatal("defaultClientSecret 为空时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量含空格(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = " "
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量仅含空格时应返回错误")
|
||||
t.Fatal("defaultClientSecret 仅含空格时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = " valid-secret "
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
@@ -670,13 +680,17 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
|
||||
t.Errorf("ClientID 不匹配: got %s", ClientID)
|
||||
}
|
||||
if ClientSecret != "" {
|
||||
t.Error("ClientSecret 应为空字符串")
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
||||
}
|
||||
if secret != "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" {
|
||||
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.84.2 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -206,6 +206,7 @@ type modelInfo struct {
|
||||
var modelInfoMap = map[string]modelInfo{
|
||||
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
||||
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
||||
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
|
||||
"claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"},
|
||||
"claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"},
|
||||
}
|
||||
|
||||
@@ -11,8 +11,13 @@ const (
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
BetaTokenCounting = "token-counting-2024-11-01"
|
||||
BetaContext1M = "context-1m-2025-08-07"
|
||||
BetaFastMode = "fast-mode-2026-02-01"
|
||||
)
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
|
||||
@@ -38,10 +38,8 @@ const (
|
||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
// GeminiCLIOAuthClientSecret is intentionally not embedded in this repository.
|
||||
// If you rely on the built-in Gemini CLI OAuth client, you MUST provide its client_secret via config/env.
|
||||
GeminiCLIOAuthClientSecret = ""
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
|
||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||
|
||||
@@ -408,11 +408,10 @@ func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
|
||||
// 不设置环境变量,也不提供 client 凭据,EffectiveOAuthConfig 应该报错
|
||||
func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := BuildAuthorizationURL(
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
@@ -420,8 +419,11 @@ func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("当 EffectiveOAuthConfig 失败时,BuildAuthorizationURL 应该返回错误")
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 不应报错: %v", err)
|
||||
}
|
||||
if !strings.Contains(authURL, "client_id="+GeminiCLIOAuthClientID) {
|
||||
t.Errorf("应使用内置 Gemini CLI client_id,实际 URL: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -685,15 +687,17 @@ func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
|
||||
// 不设置环境变量且不提供凭据,应该报错
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("没有内置 secret 且未提供凭据时应该报错")
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("不设置环境变量时应回退到内置 secret,实际报错: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
||||
t.Errorf("错误消息应提及环境变量 %s,实际: %v", GeminiCLIOAuthClientSecretEnv, err)
|
||||
if strings.TrimSpace(cfg.ClientSecret) == "" {
|
||||
t.Error("ClientSecret 不应为空")
|
||||
}
|
||||
if cfg.ClientID != GeminiCLIOAuthClientID {
|
||||
t.Errorf("ClientID 应回退为内置客户端 ID,实际: %q", cfg.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
9
backend/internal/repository/gemini_drive_client.go
Normal file
9
backend/internal/repository/gemini_drive_client.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package repository
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
|
||||
// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations.
|
||||
// Returned as geminicli.DriveClient interface for DI (Strategy A).
|
||||
func NewGeminiDriveClient() geminicli.DriveClient {
|
||||
return geminicli.NewDriveClient()
|
||||
}
|
||||
@@ -106,6 +106,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewOpenAIOAuthClient,
|
||||
NewGeminiOAuthClient,
|
||||
NewGeminiCliCodeAssistClient,
|
||||
NewGeminiDriveClient,
|
||||
|
||||
ProvideEnt,
|
||||
ProvideSQLDB,
|
||||
|
||||
@@ -219,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
|
||||
@@ -372,6 +372,13 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
ensureAntigravityDefaultPassthroughs(result, []string{
|
||||
"gemini-3-flash",
|
||||
"gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low",
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
@@ -382,6 +389,27 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
|
||||
if mapping == nil || model == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := mapping[model]; exists {
|
||||
return
|
||||
}
|
||||
for pattern := range mapping {
|
||||
if matchWildcard(pattern, model) {
|
||||
return
|
||||
}
|
||||
}
|
||||
mapping[model] = model
|
||||
}
|
||||
|
||||
func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []string) {
|
||||
for _, model := range models {
|
||||
ensureAntigravityDefaultPassthrough(mapping, model)
|
||||
}
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
||||
// 如果未配置 mapping,返回 true(允许所有模型)
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
|
||||
66
backend/internal/service/account_intercept_warmup_test.go
Normal file
66
backend/internal/service/account_intercept_warmup_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_IsInterceptWarmupEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil credentials",
|
||||
credentials: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
credentials: map[string]any{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field not present",
|
||||
credentials: map[string]any{"access_token": "tok"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field is true",
|
||||
credentials: map[string]any{"intercept_warmup_requests": true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "field is false",
|
||||
credentials: map[string]any{"intercept_warmup_requests": false},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field is string true",
|
||||
credentials: map[string]any{"intercept_warmup_requests": "true"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field is int 1",
|
||||
credentials: map[string]any{"intercept_warmup_requests": 1},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field is nil",
|
||||
credentials: map[string]any{"intercept_warmup_requests": nil},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Account{Credentials: tt.credentials}
|
||||
result := a.IsInterceptWarmupEnabled()
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -217,12 +218,20 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
}
|
||||
|
||||
if account.Platform == PlatformGemini {
|
||||
return s.getGeminiUsage(ctx, account)
|
||||
usage, err := s.getGeminiUsage(ctx, account)
|
||||
if err == nil {
|
||||
s.tryClearRecoverableAccountError(ctx, account)
|
||||
}
|
||||
return usage, err
|
||||
}
|
||||
|
||||
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.getAntigravityUsage(ctx, account)
|
||||
usage, err := s.getAntigravityUsage(ctx, account)
|
||||
if err == nil {
|
||||
s.tryClearRecoverableAccountError(ctx, account)
|
||||
}
|
||||
return usage, err
|
||||
}
|
||||
|
||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||
@@ -256,6 +265,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
||||
s.addWindowStats(ctx, account, usage)
|
||||
|
||||
s.tryClearRecoverableAccountError(ctx, account)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -486,6 +496,32 @@ func parseTime(s string) (time.Time, error) {
|
||||
return time.Time{}, fmt.Errorf("unable to parse time: %s", s)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) tryClearRecoverableAccountError(ctx context.Context, account *Account) {
|
||||
if account == nil || account.Status != StatusError {
|
||||
return
|
||||
}
|
||||
|
||||
msg := strings.ToLower(strings.TrimSpace(account.ErrorMessage))
|
||||
if msg == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.Contains(msg, "token refresh failed") &&
|
||||
!strings.Contains(msg, "invalid_client") &&
|
||||
!strings.Contains(msg, "missing_project_id") &&
|
||||
!strings.Contains(msg, "unauthenticated") {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
|
||||
log.Printf("[usage] failed to clear recoverable account error for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
account.Status = StatusActive
|
||||
account.ErrorMessage = ""
|
||||
}
|
||||
|
||||
// buildUsageInfo 构建UsageInfo
|
||||
func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo {
|
||||
info := &UsageInfo{
|
||||
|
||||
@@ -267,3 +267,50 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3-pro-high": "gemini-3.1-pro-high",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping["gemini-3-flash"] != "gemini-3-flash" {
|
||||
t.Fatalf("expected gemini-3-flash passthrough to be auto-filled, got: %q", mapping["gemini-3-flash"])
|
||||
}
|
||||
if mapping["gemini-3.1-pro-high"] != "gemini-3.1-pro-high" {
|
||||
t.Fatalf("expected gemini-3.1-pro-high passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-high"])
|
||||
}
|
||||
if mapping["gemini-3.1-pro-low"] != "gemini-3.1-pro-low" {
|
||||
t.Fatalf("expected gemini-3.1-pro-low passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-low"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3*": "gemini-3.1-pro-high",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mapping := account.GetModelMapping()
|
||||
if _, exists := mapping["gemini-3-flash"]; exists {
|
||||
t.Fatalf("did not expect explicit gemini-3-flash passthrough when wildcard already exists")
|
||||
}
|
||||
if _, exists := mapping["gemini-3.1-pro-high"]; exists {
|
||||
t.Fatalf("did not expect explicit gemini-3.1-pro-high passthrough when wildcard already exists")
|
||||
}
|
||||
if _, exists := mapping["gemini-3.1-pro-low"]; exists {
|
||||
t.Fatalf("did not expect explicit gemini-3.1-pro-low passthrough when wildcard already exists")
|
||||
}
|
||||
if mapped := account.GetMappedModel("gemini-3-flash"); mapped != "gemini-3.1-pro-high" {
|
||||
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ type AdminService interface {
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
||||
|
||||
// Proxy management
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
|
||||
@@ -2114,6 +2115,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
|
||||
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
|
||||
if s.proxyLatencyCache == nil || len(proxies) == 0 {
|
||||
return
|
||||
|
||||
@@ -87,7 +87,6 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
||||
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
|
||||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||
)
|
||||
@@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||||
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||||
billingModel := mappedModel
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -1370,6 +1370,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
|
||||
}
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
resp := result.resp
|
||||
@@ -1618,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Model: billingModel, // 使用映射模型用于计费和日志
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -1972,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if mappedModel == "" {
|
||||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
||||
}
|
||||
billingModel := mappedModel
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -2042,6 +2047,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response")
|
||||
}
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
resp := result.resp
|
||||
@@ -2197,7 +2206,7 @@ handleSuccess:
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Model: billingModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -2642,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
defaultDur := s.getDefaultRateLimitDuration()
|
||||
|
||||
// 尝试解析模型 key 并设置模型级限流
|
||||
modelKey := resolveAntigravityModelKey(requestedModel)
|
||||
//
|
||||
// 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6),
|
||||
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
|
||||
// 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。
|
||||
modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||
if strings.TrimSpace(modelKey) == "" {
|
||||
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
|
||||
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
|
||||
modelKey = resolveAntigravityModelKey(requestedModel)
|
||||
}
|
||||
if modelKey != "" {
|
||||
ra := s.resolveResetTime(resetAt, defaultDur)
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
||||
@@ -3739,14 +3757,17 @@ func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
|
||||
}
|
||||
|
||||
// isImageGenerationModel 判断模型是否为图片生成模型
|
||||
// 支持的模型:gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等
|
||||
// 支持的模型:gemini-3.1-flash-image, gemini-3-pro-image, gemini-2.5-flash-image 等
|
||||
func isImageGenerationModel(model string) bool {
|
||||
modelLower := strings.ToLower(model)
|
||||
// 移除 models/ 前缀
|
||||
modelLower = strings.TrimPrefix(modelLower, "models/")
|
||||
|
||||
// 精确匹配或前缀匹配
|
||||
return modelLower == "gemini-3-pro-image" ||
|
||||
return modelLower == "gemini-3.1-flash-image" ||
|
||||
modelLower == "gemini-3.1-flash-image-preview" ||
|
||||
strings.HasPrefix(modelLower, "gemini-3.1-flash-image-") ||
|
||||
modelLower == "gemini-3-pro-image" ||
|
||||
modelLower == "gemini-3-pro-image-preview" ||
|
||||
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
|
||||
modelLower == "gemini-2.5-flash-image" ||
|
||||
@@ -3881,7 +3902,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
originalModel := claudeReq.Model
|
||||
billingModel := originalModel
|
||||
|
||||
// 构建上游请求 URL
|
||||
upstreamURL := baseURL + "/v1/messages"
|
||||
@@ -3934,7 +3954,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Model: originalModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -3975,7 +3995,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Model: originalModel,
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: duration,
|
||||
FirstTokenMs: firstTokenMs,
|
||||
|
||||
@@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
type antigravitySettingRepoStub struct{}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
@@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
|
||||
// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型
|
||||
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hello"},
|
||||
},
|
||||
"max_tokens": 16,
|
||||
"stream": true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-bill-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
const mappedModel = "gemini-3-pro-high"
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Name: "acc-forward-billing",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-5": mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
|
||||
// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型
|
||||
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-bill-2"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
const mappedModel = "gemini-3-pro-high"
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "acc-gemini-billing",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-2.5-flash": mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_UsageAndFirstToken
|
||||
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
||||
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||
|
||||
@@ -76,6 +76,12 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
},
|
||||
|
||||
// 3. 默认映射中的透传(映射到自己)
|
||||
{
|
||||
name: "默认映射透传 - claude-sonnet-4-6",
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-6",
|
||||
},
|
||||
{
|
||||
name: "默认映射透传 - claude-sonnet-4-5",
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
|
||||
@@ -197,6 +197,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景
|
||||
// 验证:requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking)
|
||||
func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("5s")
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false)
|
||||
|
||||
require.Nil(t, result)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
||||
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
||||
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
||||
|
||||
@@ -133,6 +133,18 @@ func (s *BillingService) initFallbackPricing() {
|
||||
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 4.6 Opus (与4.5同价)
|
||||
s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"]
|
||||
|
||||
// Gemini 3.1 Pro
|
||||
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-6, // $2 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
CacheCreationPricePerToken: 2e-6, // $2 per MTok
|
||||
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
}
|
||||
|
||||
// getFallbackPricing 根据模型系列获取回退价格
|
||||
@@ -141,6 +153,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
|
||||
// 按模型系列匹配
|
||||
if strings.Contains(modelLower, "opus") {
|
||||
if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") {
|
||||
return s.fallbackPrices["claude-opus-4.6"]
|
||||
}
|
||||
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
|
||||
return s.fallbackPrices["claude-opus-4.5"]
|
||||
}
|
||||
@@ -158,6 +173,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
}
|
||||
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
|
||||
return s.fallbackPrices["gemini-3.1-pro"]
|
||||
}
|
||||
|
||||
// 默认使用Sonnet价格
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
@@ -525,7 +543,10 @@ func (s *BillingService) getDefaultImagePrice(model string, imageSize string) fl
|
||||
basePrice = 0.134
|
||||
}
|
||||
|
||||
// 4K 尺寸翻倍
|
||||
// 2K 尺寸 1.5 倍,4K 尺寸翻倍
|
||||
if imageSize == "2K" {
|
||||
return basePrice * 1.5
|
||||
}
|
||||
if imageSize == "4K" {
|
||||
return basePrice * 2
|
||||
}
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
func TestCalculateImageCost_DefaultPricing(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值
|
||||
|
||||
// 2K 尺寸,默认价格 $0.134
|
||||
// 2K 尺寸,默认价格 $0.134 * 1.5 = $0.201
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.134, cost.ActualCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001)
|
||||
|
||||
// 多张图片
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0)
|
||||
require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.603, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格
|
||||
@@ -63,13 +63,13 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
|
||||
|
||||
// 费率倍数 1.5x
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001) // TotalCost 不变
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // ActualCost = 0.134 * 1.5
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
|
||||
require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
|
||||
|
||||
// 费率倍数 2.0x
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.536, cost.ActualCost, 0.0001)
|
||||
require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.804, cost.ActualCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroCount 测试 imageCount=0
|
||||
@@ -95,8 +95,8 @@ func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.134, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
|
||||
@@ -127,9 +127,9 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.10, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 回退默认价格 $0.134
|
||||
// 2K 回退默认价格 $0.201 (1.5倍)
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 回退默认价格 $0.268 (翻倍)
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
@@ -140,10 +140,10 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
|
||||
func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil
|
||||
|
||||
// 1K 和 2K 使用相同的默认价格 $0.134
|
||||
// 1K 默认价格 $0.134,2K 默认价格 $0.201 (1.5倍)
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -262,6 +263,107 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
respBody string
|
||||
wantPassthrough bool
|
||||
}{
|
||||
{
|
||||
name: "404 endpoint not found passes through as 404",
|
||||
statusCode: http.StatusNotFound,
|
||||
respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
|
||||
wantPassthrough: true,
|
||||
},
|
||||
{
|
||||
name: "404 generic not found passes through as 404",
|
||||
statusCode: http.StatusNotFound,
|
||||
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
|
||||
wantPassthrough: true,
|
||||
},
|
||||
{
|
||||
name: "400 Invalid URL does not passthrough",
|
||||
statusCode: http.StatusBadRequest,
|
||||
respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`,
|
||||
wantPassthrough: false,
|
||||
},
|
||||
{
|
||||
name: "400 model error does not passthrough",
|
||||
statusCode: http.StatusBadRequest,
|
||||
respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`,
|
||||
wantPassthrough: false,
|
||||
},
|
||||
{
|
||||
name: "500 internal error does not passthrough",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
respBody: `{"error":{"message":"internal error","type":"api_error"}}`,
|
||||
wantPassthrough: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-5-20250929","messages":[{"role":"user","content":"hi"}]}`)
|
||||
parsed := &ParsedRequest{Body: body, Model: "claude-sonnet-4-5-20250929"}
|
||||
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: tt.statusCode,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(tt.respBody)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: nil,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Name: "proxy-acc",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-proxy",
|
||||
"base_url": "https://proxy.example.com",
|
||||
},
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
|
||||
if tt.wantPassthrough {
|
||||
// 返回 nil(不记录为错误),HTTP 状态码 404 + Anthropic 错误体
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||||
var errResp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &errResp))
|
||||
require.Equal(t, "error", errResp["type"])
|
||||
errObj, ok := errResp["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "not_found_error", errObj["type"])
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Equal(t, tt.statusCode, rec.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -22,60 +24,78 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
|
||||
}
|
||||
|
||||
func TestStripBetaToken(t *testing.T) {
|
||||
func TestStripBetaTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
token string
|
||||
tokens []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "token in middle",
|
||||
name: "single token in middle",
|
||||
header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
|
||||
token: "context-1m-2025-08-07",
|
||||
tokens: []string{"context-1m-2025-08-07"},
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "token at start",
|
||||
name: "single token at start",
|
||||
header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
token: "context-1m-2025-08-07",
|
||||
tokens: []string{"context-1m-2025-08-07"},
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "token at end",
|
||||
name: "single token at end",
|
||||
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07",
|
||||
token: "context-1m-2025-08-07",
|
||||
tokens: []string{"context-1m-2025-08-07"},
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "token not present",
|
||||
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
token: "context-1m-2025-08-07",
|
||||
tokens: []string{"context-1m-2025-08-07"},
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "empty header",
|
||||
header: "",
|
||||
token: "context-1m-2025-08-07",
|
||||
tokens: []string{"context-1m-2025-08-07"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14",
|
||||
token: "context-1m-2025-08-07",
|
||||
tokens: []string{"context-1m-2025-08-07"},
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "only token",
|
||||
header: "context-1m-2025-08-07",
|
||||
token: "context-1m-2025-08-07",
|
||||
tokens: []string{"context-1m-2025-08-07"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil tokens",
|
||||
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
tokens: nil,
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "multiple tokens removed",
|
||||
header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14,fast-mode-2026-02-01",
|
||||
tokens: []string{"context-1m-2025-08-07", "fast-mode-2026-02-01"},
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "DroppedBetas removes both context-1m and fast-mode",
|
||||
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
|
||||
tokens: claude.DroppedBetas,
|
||||
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := stripBetaToken(tt.header, tt.token)
|
||||
got := stripBetaTokens(tt.header, tt.tokens)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
@@ -90,3 +110,29 @@ func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
|
||||
require.NotContains(t, got, "context-1m-2025-08-07")
|
||||
}
|
||||
|
||||
func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
|
||||
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
|
||||
incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20"
|
||||
drop := droppedBetaSet()
|
||||
|
||||
got := mergeAnthropicBetaDropping(required, incoming, drop)
|
||||
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
|
||||
require.NotContains(t, got, "context-1m-2025-08-07")
|
||||
require.NotContains(t, got, "fast-mode-2026-02-01")
|
||||
}
|
||||
|
||||
func TestDroppedBetaSet(t *testing.T) {
|
||||
// Base set contains DroppedBetas
|
||||
base := droppedBetaSet()
|
||||
require.Contains(t, base, claude.BetaContext1M)
|
||||
require.Contains(t, base, claude.BetaFastMode)
|
||||
require.Len(t, base, len(claude.DroppedBetas))
|
||||
|
||||
// With extra tokens
|
||||
extended := droppedBetaSet(claude.BetaClaudeCode)
|
||||
require.Contains(t, extended, claude.BetaContext1M)
|
||||
require.Contains(t, extended, claude.BetaFastMode)
|
||||
require.Contains(t, extended, claude.BetaClaudeCode)
|
||||
require.Len(t, extended, len(claude.DroppedBetas)+1)
|
||||
}
|
||||
|
||||
@@ -895,6 +895,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Priority: 1,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Priority: 2,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}},
|
||||
},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号")
|
||||
|
||||
acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "supporting model")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(50)
|
||||
@@ -1070,6 +1119,36 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-无映射配置-支持所有模型",
|
||||
account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
|
||||
},
|
||||
},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
|
||||
},
|
||||
},
|
||||
model: "gemini-2.5-pro",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -470,7 +470,7 @@ type ForwardResult struct {
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
|
||||
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
|
||||
// 图片生成计费字段(图片生成模型使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
||||
|
||||
@@ -2825,10 +2825,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
}
|
||||
// Gemini API Key 账户直接透传,由上游判断模型是否支持
|
||||
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
|
||||
return true
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
@@ -4429,12 +4425,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// messages requests typically use only oauth + interleaved-thinking.
|
||||
// Also drop claude-code beta if a downstream client added it.
|
||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||
drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}}
|
||||
drop := droppedBetaSet(claude.BetaClaudeCode)
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
||||
} else {
|
||||
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||
req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M))
|
||||
req.Header.Set("anthropic-beta", stripBetaTokens(s.getBetaHeader(modelID, clientBetaHeader), claude.DroppedBetas))
|
||||
}
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
@@ -4588,23 +4584,45 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
|
||||
return strings.Join(out, ",")
|
||||
}
|
||||
|
||||
// stripBetaToken removes a single beta token from a comma-separated header value.
|
||||
// It short-circuits when the token is not present to avoid unnecessary allocations.
|
||||
func stripBetaToken(header, token string) string {
|
||||
if !strings.Contains(header, token) {
|
||||
// stripBetaTokens removes the given beta tokens from a comma-separated header value.
|
||||
func stripBetaTokens(header string, tokens []string) string {
|
||||
if header == "" || len(tokens) == 0 {
|
||||
return header
|
||||
}
|
||||
out := make([]string, 0, 8)
|
||||
for _, p := range strings.Split(header, ",") {
|
||||
drop := make(map[string]struct{}, len(tokens))
|
||||
for _, t := range tokens {
|
||||
drop[t] = struct{}{}
|
||||
}
|
||||
parts := strings.Split(header, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" || p == token {
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := drop[p]; ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
if len(out) == len(parts) {
|
||||
return header // no change, avoid allocation
|
||||
}
|
||||
return strings.Join(out, ",")
|
||||
}
|
||||
|
||||
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
|
||||
func droppedBetaSet(extra ...string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(claude.DroppedBetas)+len(extra))
|
||||
for _, t := range claude.DroppedBetas {
|
||||
m[t] = struct{}{}
|
||||
}
|
||||
for _, t := range extra {
|
||||
m[t] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
|
||||
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
|
||||
// headers when using Claude Code-scoped OAuth credentials.
|
||||
@@ -5997,9 +6015,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
||||
// Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
|
||||
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
|
||||
if account.Platform == PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6203,6 +6222,17 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
// 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。
|
||||
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
logger.LegacyPrintf("service.gateway",
|
||||
"[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s",
|
||||
account.ID, account.Name, truncateString(upstreamMsg, 512))
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream")
|
||||
return nil
|
||||
}
|
||||
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
@@ -6379,7 +6409,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
incomingBeta := req.Header.Get("anthropic-beta")
|
||||
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
|
||||
drop := map[string]struct{}{claude.BetaContext1M: {}}
|
||||
drop := droppedBetaSet()
|
||||
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
|
||||
} else {
|
||||
clientBetaHeader := req.Header.Get("anthropic-beta")
|
||||
@@ -6390,7 +6420,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
||||
beta = beta + "," + claude.BetaTokenCounting
|
||||
}
|
||||
req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M))
|
||||
req.Header.Set("anthropic-beta", stripBetaTokens(beta, claude.DroppedBetas))
|
||||
}
|
||||
}
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
|
||||
@@ -54,6 +54,7 @@ type GeminiOAuthService struct {
|
||||
proxyRepo ProxyRepository
|
||||
oauthClient GeminiOAuthClient
|
||||
codeAssist GeminiCliCodeAssistClient
|
||||
driveClient geminicli.DriveClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
@@ -66,6 +67,7 @@ func NewGeminiOAuthService(
|
||||
proxyRepo ProxyRepository,
|
||||
oauthClient GeminiOAuthClient,
|
||||
codeAssist GeminiCliCodeAssistClient,
|
||||
driveClient geminicli.DriveClient,
|
||||
cfg *config.Config,
|
||||
) *GeminiOAuthService {
|
||||
return &GeminiOAuthService{
|
||||
@@ -73,6 +75,7 @@ func NewGeminiOAuthService(
|
||||
proxyRepo: proxyRepo,
|
||||
oauthClient: oauthClient,
|
||||
codeAssist: codeAssist,
|
||||
driveClient: driveClient,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -362,9 +365,8 @@ func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken
|
||||
|
||||
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...")
|
||||
driveClient := geminicli.NewDriveClient()
|
||||
|
||||
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||
storageInfo, err := s.driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Check if it's a 403 (scope not granted)
|
||||
if strings.Contains(err.Error(), "status 403") {
|
||||
|
||||
@@ -101,7 +101,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg)
|
||||
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
|
||||
if tt.wantErrSubstr != "" {
|
||||
if err == nil {
|
||||
@@ -487,7 +487,7 @@ func TestIsNonRetryableGeminiOAuthError(t *testing.T) {
|
||||
func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
t.Run("完整字段", func(t *testing.T) {
|
||||
@@ -687,7 +687,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg)
|
||||
defer svc.Stop()
|
||||
|
||||
result := svc.GetOAuthConfig()
|
||||
@@ -709,7 +709,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) {
|
||||
func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
|
||||
|
||||
// 调用 Stop 不应 panic
|
||||
svc.Stop()
|
||||
@@ -806,6 +806,18 @@ func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context,
|
||||
panic("not impl")
|
||||
}
|
||||
|
||||
// mockDriveClient implements geminicli.DriveClient for tests.
|
||||
type mockDriveClient struct {
|
||||
getStorageQuotaFunc func(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error)
|
||||
}
|
||||
|
||||
func (m *mockDriveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) {
|
||||
if m.getStorageQuotaFunc != nil {
|
||||
return m.getStorageQuotaFunc(ctx, accessToken, proxyURL)
|
||||
}
|
||||
return nil, fmt.Errorf("drive API not available in test")
|
||||
}
|
||||
|
||||
// =====================
|
||||
// 新增测试:GeminiOAuthService.RefreshToken(含重试逻辑)
|
||||
// =====================
|
||||
@@ -825,7 +837,7 @@ func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "")
|
||||
@@ -852,7 +864,7 @@ func TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
_, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "")
|
||||
@@ -881,7 +893,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "")
|
||||
@@ -903,7 +915,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) {
|
||||
func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -923,7 +935,7 @@ func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) {
|
||||
func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -958,7 +970,7 @@ func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -997,7 +1009,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *test
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -1042,7 +1054,7 @@ func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
// 无 oauth_type 凭据的旧账号
|
||||
@@ -1090,7 +1102,7 @@ func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(proxyRepo, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(proxyRepo, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
proxyID := int64(5)
|
||||
@@ -1132,7 +1144,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetec
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{})
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -1181,7 +1193,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpt
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{})
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -1214,7 +1226,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -1254,7 +1266,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree(
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &mockDriveClient{}, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -1308,7 +1320,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *t
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, cfg)
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, cfg)
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -1341,7 +1353,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t
|
||||
}
|
||||
|
||||
// 无自定义 OAuth 客户端,无法 fallback
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
@@ -1370,7 +1382,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t
|
||||
func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
_, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{
|
||||
@@ -1389,7 +1401,7 @@ func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
|
||||
func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
// 手动创建 session(必须设置 CreatedAt,否则会因 TTL 过期被拒绝)
|
||||
@@ -1416,7 +1428,7 @@ func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) {
|
||||
func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
|
||||
defer svc.Stop()
|
||||
|
||||
svc.sessionStore.Set("test-session", &geminicli.OAuthSession{
|
||||
|
||||
@@ -107,12 +107,12 @@ func TestIsModelRateLimited(t *testing.T) {
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3.1-pro-high",
|
||||
name: "antigravity platform - gemini-3-pro-preview mapped to gemini-3-pro-high",
|
||||
account: &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"gemini-3.1-pro-high": map[string]any{
|
||||
"gemini-3-pro-high": map[string]any{
|
||||
"rate_limit_reset_at": future,
|
||||
},
|
||||
},
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,515 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ---------- 辅助解析函数(复制生产代码中的 gjson 解析逻辑,用于单元测试) ----------
|
||||
|
||||
// testParseUploadOrCreateTaskID 模拟 UploadImage / CreateImageTask / CreateVideoTask 中
|
||||
// 用 gjson.GetBytes(respBody, "id") 提取 id 的逻辑。
|
||||
func testParseUploadOrCreateTaskID(respBody []byte) (string, error) {
|
||||
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||||
if id == "" {
|
||||
return "", assert.AnError // 占位错误,表示 "missing id"
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// testParseFetchRecentImageTask 模拟 fetchRecentImageTask 中的 gjson.ForEach 解析逻辑。
|
||||
func testParseFetchRecentImageTask(respBody []byte, taskID string) (*SoraImageTaskStatus, bool) {
|
||||
var found *SoraImageTaskStatus
|
||||
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
|
||||
if item.Get("id").String() != taskID {
|
||||
return true // continue
|
||||
}
|
||||
status := strings.TrimSpace(item.Get("status").String())
|
||||
progress := item.Get("progress_pct").Float()
|
||||
var urls []string
|
||||
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
|
||||
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
|
||||
urls = append(urls, u)
|
||||
}
|
||||
return true
|
||||
})
|
||||
found = &SoraImageTaskStatus{
|
||||
ID: taskID,
|
||||
Status: status,
|
||||
ProgressPct: progress,
|
||||
URLs: urls,
|
||||
}
|
||||
return false // break
|
||||
})
|
||||
if found != nil {
|
||||
return found, true
|
||||
}
|
||||
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false
|
||||
}
|
||||
|
||||
// testParseGetVideoTaskPending 模拟 GetVideoTask 中解析 pending 列表的逻辑。
|
||||
func testParseGetVideoTaskPending(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
|
||||
pendingResult := gjson.ParseBytes(respBody)
|
||||
if !pendingResult.IsArray() {
|
||||
return nil, false
|
||||
}
|
||||
var pendingFound *SoraVideoTaskStatus
|
||||
pendingResult.ForEach(func(_, task gjson.Result) bool {
|
||||
if task.Get("id").String() != taskID {
|
||||
return true
|
||||
}
|
||||
progress := 0
|
||||
if v := task.Get("progress_pct"); v.Exists() {
|
||||
progress = int(v.Float() * 100)
|
||||
}
|
||||
status := strings.TrimSpace(task.Get("status").String())
|
||||
pendingFound = &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: status,
|
||||
ProgressPct: progress,
|
||||
}
|
||||
return false
|
||||
})
|
||||
if pendingFound != nil {
|
||||
return pendingFound, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// testParseGetVideoTaskDrafts 模拟 GetVideoTask 中解析 drafts 列表的逻辑。
|
||||
func testParseGetVideoTaskDrafts(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
|
||||
var draftFound *SoraVideoTaskStatus
|
||||
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
|
||||
if draft.Get("task_id").String() != taskID {
|
||||
return true
|
||||
}
|
||||
kind := strings.TrimSpace(draft.Get("kind").String())
|
||||
reason := strings.TrimSpace(draft.Get("reason_str").String())
|
||||
if reason == "" {
|
||||
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
|
||||
}
|
||||
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
|
||||
if urlStr == "" {
|
||||
urlStr = strings.TrimSpace(draft.Get("url").String())
|
||||
}
|
||||
|
||||
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
|
||||
msg := reason
|
||||
if msg == "" {
|
||||
msg = "Content violates guardrails"
|
||||
}
|
||||
draftFound = &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: msg,
|
||||
}
|
||||
} else {
|
||||
draftFound = &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "completed",
|
||||
URLs: []string{urlStr},
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
if draftFound != nil {
|
||||
return draftFound, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ===================== Test 1: TestSoraParseUploadResponse =====================
|
||||
|
||||
func TestSoraParseUploadResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常 id",
|
||||
body: `{"id":"file-abc123","status":"uploaded"}`,
|
||||
wantID: "file-abc123",
|
||||
},
|
||||
{
|
||||
name: "空 id",
|
||||
body: `{"id":"","status":"uploaded"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无 id 字段",
|
||||
body: `{"status":"uploaded"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "id 全为空白",
|
||||
body: `{"id":" ","status":"uploaded"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "id 前后有空白",
|
||||
body: `{"id":" file-trimmed ","status":"uploaded"}`,
|
||||
wantID: "file-trimmed",
|
||||
},
|
||||
{
|
||||
name: "空 JSON 对象",
|
||||
body: `{}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
|
||||
if tt.wantErr {
|
||||
require.Error(t, err, "应返回错误")
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantID, id)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== Test 2: TestSoraParseCreateTaskResponse =====================
|
||||
|
||||
func TestSoraParseCreateTaskResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常任务 id",
|
||||
body: `{"id":"task-123"}`,
|
||||
wantID: "task-123",
|
||||
},
|
||||
{
|
||||
name: "缺失 id",
|
||||
body: `{"status":"created"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "空 id",
|
||||
body: `{"id":" "}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "id 为数字(gjson 转字符串)",
|
||||
body: `{"id":123}`,
|
||||
wantID: "123",
|
||||
},
|
||||
{
|
||||
name: "id 含特殊字符",
|
||||
body: `{"id":"task-abc-def-456-ghi"}`,
|
||||
wantID: "task-abc-def-456-ghi",
|
||||
},
|
||||
{
|
||||
name: "额外字段不影响解析",
|
||||
body: `{"id":"task-999","type":"image_gen","extra":"data"}`,
|
||||
wantID: "task-999",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
|
||||
if tt.wantErr {
|
||||
require.Error(t, err, "应返回错误")
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantID, id)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== Test 3: TestSoraParseFetchRecentImageTask =====================
|
||||
|
||||
func TestSoraParseFetchRecentImageTask(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
taskID string
|
||||
wantFound bool
|
||||
wantStatus string
|
||||
wantProgress float64
|
||||
wantURLs []string
|
||||
}{
|
||||
{
|
||||
name: "匹配已完成任务",
|
||||
body: `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1.0,"generations":[{"url":"https://example.com/img.png"}]}]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: true,
|
||||
wantStatus: "completed",
|
||||
wantProgress: 1.0,
|
||||
wantURLs: []string{"https://example.com/img.png"},
|
||||
},
|
||||
{
|
||||
name: "匹配处理中任务",
|
||||
body: `{"task_responses":[{"id":"task-2","status":"processing","progress_pct":0.5,"generations":[]}]}`,
|
||||
taskID: "task-2",
|
||||
wantFound: true,
|
||||
wantStatus: "processing",
|
||||
wantProgress: 0.5,
|
||||
wantURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "无匹配任务",
|
||||
body: `{"task_responses":[{"id":"other","status":"completed"}]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
wantStatus: "processing",
|
||||
},
|
||||
{
|
||||
name: "空 task_responses",
|
||||
body: `{"task_responses":[]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
wantStatus: "processing",
|
||||
},
|
||||
{
|
||||
name: "缺少 task_responses 字段",
|
||||
body: `{"other":"data"}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
wantStatus: "processing",
|
||||
},
|
||||
{
|
||||
name: "多个任务中精准匹配",
|
||||
body: `{"task_responses":[{"id":"task-a","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"}]},{"id":"task-b","status":"processing","progress_pct":0.3,"generations":[]},{"id":"task-c","status":"failed","progress_pct":0}]}`,
|
||||
taskID: "task-b",
|
||||
wantFound: true,
|
||||
wantStatus: "processing",
|
||||
wantProgress: 0.3,
|
||||
wantURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "多个 generations",
|
||||
body: `{"task_responses":[{"id":"task-m","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"},{"url":"https://a.com/2.png"},{"url":""}]}]}`,
|
||||
taskID: "task-m",
|
||||
wantFound: true,
|
||||
wantStatus: "completed",
|
||||
wantProgress: 1.0,
|
||||
wantURLs: []string{"https://a.com/1.png", "https://a.com/2.png"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
status, found := testParseFetchRecentImageTask([]byte(tt.body), tt.taskID)
|
||||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||
require.NotNil(t, status)
|
||||
require.Equal(t, tt.taskID, status.ID)
|
||||
require.Equal(t, tt.wantStatus, status.Status)
|
||||
if tt.wantFound {
|
||||
require.InDelta(t, tt.wantProgress, status.ProgressPct, 0.001, "进度不匹配")
|
||||
require.Equal(t, tt.wantURLs, status.URLs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== Test 4: TestSoraParseGetVideoTaskPending =====================
|
||||
|
||||
func TestSoraParseGetVideoTaskPending(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
taskID string
|
||||
wantFound bool
|
||||
wantStatus string
|
||||
wantProgress int
|
||||
}{
|
||||
{
|
||||
name: "匹配 pending 任务",
|
||||
body: `[{"id":"task-1","status":"processing","progress_pct":0.5}]`,
|
||||
taskID: "task-1",
|
||||
wantFound: true,
|
||||
wantStatus: "processing",
|
||||
wantProgress: 50,
|
||||
},
|
||||
{
|
||||
name: "进度为 0",
|
||||
body: `[{"id":"task-2","status":"queued","progress_pct":0}]`,
|
||||
taskID: "task-2",
|
||||
wantFound: true,
|
||||
wantStatus: "queued",
|
||||
wantProgress: 0,
|
||||
},
|
||||
{
|
||||
name: "进度为 1(100%)",
|
||||
body: `[{"id":"task-3","status":"completing","progress_pct":1.0}]`,
|
||||
taskID: "task-3",
|
||||
wantFound: true,
|
||||
wantStatus: "completing",
|
||||
wantProgress: 100,
|
||||
},
|
||||
{
|
||||
name: "空数组",
|
||||
body: `[]`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "无匹配 id",
|
||||
body: `[{"id":"task-other","status":"processing","progress_pct":0.3}]`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "多个任务精准匹配",
|
||||
body: `[{"id":"task-a","status":"processing","progress_pct":0.2},{"id":"task-b","status":"queued","progress_pct":0},{"id":"task-c","status":"processing","progress_pct":0.8}]`,
|
||||
taskID: "task-c",
|
||||
wantFound: true,
|
||||
wantStatus: "processing",
|
||||
wantProgress: 80,
|
||||
},
|
||||
{
|
||||
name: "非数组 JSON",
|
||||
body: `{"id":"task-1","status":"processing"}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "无 progress_pct 字段",
|
||||
body: `[{"id":"task-4","status":"pending"}]`,
|
||||
taskID: "task-4",
|
||||
wantFound: true,
|
||||
wantStatus: "pending",
|
||||
wantProgress: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
status, found := testParseGetVideoTaskPending([]byte(tt.body), tt.taskID)
|
||||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||
if tt.wantFound {
|
||||
require.NotNil(t, status)
|
||||
require.Equal(t, tt.taskID, status.ID)
|
||||
require.Equal(t, tt.wantStatus, status.Status)
|
||||
require.Equal(t, tt.wantProgress, status.ProgressPct)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== Test 5: TestSoraParseGetVideoTaskDrafts =====================
|
||||
|
||||
func TestSoraParseGetVideoTaskDrafts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
taskID string
|
||||
wantFound bool
|
||||
wantStatus string
|
||||
wantURLs []string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "正常完成的视频",
|
||||
body: `{"items":[{"task_id":"task-1","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: true,
|
||||
wantStatus: "completed",
|
||||
wantURLs: []string{"https://example.com/video.mp4"},
|
||||
},
|
||||
{
|
||||
name: "使用 url 字段回退",
|
||||
body: `{"items":[{"task_id":"task-2","kind":"video","url":"https://example.com/fallback.mp4"}]}`,
|
||||
taskID: "task-2",
|
||||
wantFound: true,
|
||||
wantStatus: "completed",
|
||||
wantURLs: []string{"https://example.com/fallback.mp4"},
|
||||
},
|
||||
{
|
||||
name: "内容违规",
|
||||
body: `{"items":[{"task_id":"task-3","kind":"sora_content_violation","reason_str":"Content policy violation"}]}`,
|
||||
taskID: "task-3",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Content policy violation",
|
||||
},
|
||||
{
|
||||
name: "内容违规 - markdown_reason_str 回退",
|
||||
body: `{"items":[{"task_id":"task-4","kind":"sora_content_violation","markdown_reason_str":"Markdown reason"}]}`,
|
||||
taskID: "task-4",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Markdown reason",
|
||||
},
|
||||
{
|
||||
name: "内容违规 - 无 reason 使用默认消息",
|
||||
body: `{"items":[{"task_id":"task-5","kind":"sora_content_violation"}]}`,
|
||||
taskID: "task-5",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Content violates guardrails",
|
||||
},
|
||||
{
|
||||
name: "有 reason_str 但非 violation kind(仍判定失败)",
|
||||
body: `{"items":[{"task_id":"task-6","kind":"video","reason_str":"Some error occurred"}]}`,
|
||||
taskID: "task-6",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Some error occurred",
|
||||
},
|
||||
{
|
||||
name: "空 URL 判定为失败",
|
||||
body: `{"items":[{"task_id":"task-7","kind":"video","downloadable_url":"","url":""}]}`,
|
||||
taskID: "task-7",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Content violates guardrails",
|
||||
},
|
||||
{
|
||||
name: "无匹配 task_id",
|
||||
body: `{"items":[{"task_id":"task-other","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "空 items",
|
||||
body: `{"items":[]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "缺少 items 字段",
|
||||
body: `{"other":"data"}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "多个 items 精准匹配",
|
||||
body: `{"items":[{"task_id":"task-a","kind":"video","downloadable_url":"https://a.com/a.mp4"},{"task_id":"task-b","kind":"sora_content_violation","reason_str":"Bad content"},{"task_id":"task-c","kind":"video","downloadable_url":"https://c.com/c.mp4"}]}`,
|
||||
taskID: "task-b",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Bad content",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
status, found := testParseGetVideoTaskDrafts([]byte(tt.body), tt.taskID)
|
||||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||
if !tt.wantFound {
|
||||
return
|
||||
}
|
||||
require.NotNil(t, status)
|
||||
require.Equal(t, tt.taskID, status.ID)
|
||||
require.Equal(t, tt.wantStatus, status.Status)
|
||||
if tt.wantErr != "" {
|
||||
require.Equal(t, tt.wantErr, status.ErrorMsg)
|
||||
}
|
||||
if tt.wantURLs != nil {
|
||||
require.Equal(t, tt.wantURLs, status.URLs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,260 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
const soraCurlCFFISidecarDefaultTimeoutSeconds = 60
|
||||
|
||||
type soraCurlCFFISidecarRequest struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers,omitempty"`
|
||||
BodyBase64 string `json:"body_base64,omitempty"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
SessionKey string `json:"session_key,omitempty"`
|
||||
Impersonate string `json:"impersonate,omitempty"`
|
||||
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
type soraCurlCFFISidecarResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status int `json:"status"`
|
||||
Headers map[string]any `json:"headers"`
|
||||
BodyBase64 string `json:"body_base64"`
|
||||
Body string `json:"body"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||
if req == nil || req.URL == nil {
|
||||
return nil, errors.New("request url is nil")
|
||||
}
|
||||
if c == nil || c.cfg == nil {
|
||||
return nil, errors.New("sora curl_cffi sidecar config is nil")
|
||||
}
|
||||
if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
return nil, errors.New("sora curl_cffi sidecar is disabled")
|
||||
}
|
||||
endpoint := c.curlCFFISidecarEndpoint()
|
||||
if endpoint == "" {
|
||||
return nil, errors.New("sora curl_cffi sidecar base_url is empty")
|
||||
}
|
||||
|
||||
bodyBytes, err := readAndRestoreRequestBody(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err)
|
||||
}
|
||||
|
||||
headers := make(map[string][]string, len(req.Header)+1)
|
||||
for key, vals := range req.Header {
|
||||
copied := make([]string, len(vals))
|
||||
copy(copied, vals)
|
||||
headers[key] = copied
|
||||
}
|
||||
if strings.TrimSpace(req.Host) != "" {
|
||||
if _, ok := headers["Host"]; !ok {
|
||||
headers["Host"] = []string{req.Host}
|
||||
}
|
||||
}
|
||||
|
||||
payload := soraCurlCFFISidecarRequest{
|
||||
Method: req.Method,
|
||||
URL: req.URL.String(),
|
||||
Headers: headers,
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
SessionKey: c.sidecarSessionKey(account, proxyURL),
|
||||
Impersonate: c.curlCFFIImpersonate(),
|
||||
TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(),
|
||||
}
|
||||
if len(bodyBytes) > 0 {
|
||||
payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes)
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err)
|
||||
}
|
||||
sidecarReq.Header.Set("Content-Type", "application/json")
|
||||
sidecarReq.Header.Set("Accept", "application/json")
|
||||
|
||||
httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second}
|
||||
sidecarResp, err := httpClient.Do(sidecarReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = sidecarResp.Body.Close()
|
||||
}()
|
||||
|
||||
sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err)
|
||||
}
|
||||
if sidecarResp.StatusCode != http.StatusOK {
|
||||
redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512)
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted)
|
||||
}
|
||||
|
||||
var payloadResp soraCurlCFFISidecarResponse
|
||||
if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err)
|
||||
}
|
||||
if msg := strings.TrimSpace(payloadResp.Error); msg != "" {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg)
|
||||
}
|
||||
statusCode := payloadResp.StatusCode
|
||||
if statusCode <= 0 {
|
||||
statusCode = payloadResp.Status
|
||||
}
|
||||
if statusCode <= 0 {
|
||||
return nil, errors.New("sora curl_cffi sidecar response missing status code")
|
||||
}
|
||||
|
||||
responseBody := []byte(payloadResp.Body)
|
||||
if strings.TrimSpace(payloadResp.BodyBase64) != "" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err)
|
||||
}
|
||||
responseBody = decoded
|
||||
}
|
||||
|
||||
respHeaders := make(http.Header)
|
||||
for key, rawVal := range payloadResp.Headers {
|
||||
for _, v := range convertSidecarHeaderValue(rawVal) {
|
||||
respHeaders.Add(key, v)
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Header: respHeaders,
|
||||
Body: io.NopCloser(bytes.NewReader(responseBody)),
|
||||
ContentLength: int64(len(responseBody)),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.ContentLength = int64(len(bodyBytes))
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFISidecarEndpoint() string {
|
||||
if c == nil || c.cfg == nil {
|
||||
return ""
|
||||
}
|
||||
raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" {
|
||||
return raw
|
||||
}
|
||||
if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" {
|
||||
parsed.Path = "/request"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return soraCurlCFFISidecarDefaultTimeoutSeconds
|
||||
}
|
||||
timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds
|
||||
if timeoutSeconds <= 0 {
|
||||
return soraCurlCFFISidecarDefaultTimeoutSeconds
|
||||
}
|
||||
return timeoutSeconds
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFIImpersonate() string {
|
||||
if c == nil || c.cfg == nil {
|
||||
return "chrome131"
|
||||
}
|
||||
impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate)
|
||||
if impersonate == "" {
|
||||
return "chrome131"
|
||||
}
|
||||
return impersonate
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool {
|
||||
if c == nil || c.cfg == nil {
|
||||
return true
|
||||
}
|
||||
return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) sidecarSessionTTLSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return 3600
|
||||
}
|
||||
ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds
|
||||
if ttl < 0 {
|
||||
return 3600
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
func convertSidecarHeaderValue(raw any) []string {
|
||||
switch val := raw.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case string:
|
||||
if strings.TrimSpace(val) == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{val}
|
||||
case []any:
|
||||
out := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
s := strings.TrimSpace(fmt.Sprint(item))
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
out := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
if strings.TrimSpace(item) != "" {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
s := strings.TrimSpace(fmt.Sprint(val))
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{s}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"math/rand"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -669,7 +671,7 @@ func processSoraCharacterUsername(usernameHint string) string {
|
||||
if usernameHint == "" {
|
||||
usernameHint = "character"
|
||||
}
|
||||
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
|
||||
return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
|
||||
@@ -829,7 +831,7 @@ func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content stri
|
||||
},
|
||||
},
|
||||
}
|
||||
encoded, _ := json.Marshal(chunk)
|
||||
encoded, _ := jsonMarshalRaw(chunk)
|
||||
if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -850,7 +852,7 @@ func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content stri
|
||||
},
|
||||
},
|
||||
}
|
||||
finalEncoded, _ := json.Marshal(finalChunk)
|
||||
finalEncoded, _ := jsonMarshalRaw(finalChunk)
|
||||
if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil {
|
||||
return &ms, err
|
||||
}
|
||||
@@ -1051,6 +1053,23 @@ func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
|
||||
return output
|
||||
}
|
||||
|
||||
// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符,
|
||||
// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。
|
||||
func jsonMarshalRaw(v any) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.SetEscapeHTML(false)
|
||||
if err := enc.Encode(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Encode 会追加换行符,去掉它
|
||||
b := buf.Bytes()
|
||||
if len(b) > 0 && b[len(b)-1] == '\n' {
|
||||
b = b[:len(b)-1]
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func buildSoraContent(mediaType string, urls []string) string {
|
||||
switch mediaType {
|
||||
case "image":
|
||||
|
||||
@@ -316,7 +316,7 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin
|
||||
}
|
||||
}
|
||||
|
||||
updatedData, err := json.Marshal(payload)
|
||||
updatedData, err := jsonMarshalRaw(payload)
|
||||
if err != nil {
|
||||
return "data: " + data, contentDelta, nil
|
||||
}
|
||||
@@ -484,7 +484,7 @@ func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel
|
||||
if originalModel != "" {
|
||||
payload["model"] = originalModel
|
||||
}
|
||||
updatedData, err := json.Marshal(payload)
|
||||
updatedData, err := jsonMarshalRaw(payload)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -181,7 +181,7 @@ func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawU
|
||||
return relative, nil
|
||||
}
|
||||
if s.debug {
|
||||
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeSoraLogURL(rawURL), err)
|
||||
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeMediaLogURL(rawURL), err)
|
||||
}
|
||||
if attempt < retries {
|
||||
time.Sleep(time.Duration(attempt*attempt) * time.Second)
|
||||
@@ -252,7 +252,7 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
|
||||
|
||||
relative := path.Join("/", mediaType, datePath, filename)
|
||||
if s.debug {
|
||||
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeSoraLogURL(rawURL), relative)
|
||||
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeMediaLogURL(rawURL), relative)
|
||||
}
|
||||
return relative, nil
|
||||
}
|
||||
@@ -305,3 +305,19 @@ func removePartialDownload(root *os.Root, filePath string) {
|
||||
}
|
||||
_ = root.Remove(filePath)
|
||||
}
|
||||
|
||||
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
|
||||
func sanitizeMediaLogURL(rawURL string) string {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
if len(rawURL) > 80 {
|
||||
return rawURL[:80] + "..."
|
||||
}
|
||||
return rawURL
|
||||
}
|
||||
safe := parsed.Scheme + "://" + parsed.Host + parsed.Path
|
||||
if len(safe) > 120 {
|
||||
return safe[:120] + "..."
|
||||
}
|
||||
return safe
|
||||
}
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type soraChallengeCooldownEntry struct {
|
||||
Until time.Time
|
||||
StatusCode int
|
||||
CFRay string
|
||||
ConsecutiveChallenges int
|
||||
LastChallengeAt time.Time
|
||||
}
|
||||
|
||||
type soraSidecarSessionEntry struct {
|
||||
SessionKey string
|
||||
ExpiresAt time.Time
|
||||
LastUsedAt time.Time
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return 900
|
||||
}
|
||||
cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds
|
||||
if cooldown <= 0 {
|
||||
return 0
|
||||
}
|
||||
return cooldown
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return nil
|
||||
}
|
||||
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
|
||||
if cooldownSeconds <= 0 {
|
||||
return nil
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
now := time.Now()
|
||||
|
||||
c.challengeCooldownMu.RLock()
|
||||
entry, ok := c.challengeCooldowns[key]
|
||||
c.challengeCooldownMu.RUnlock()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if !entry.Until.After(now) {
|
||||
c.challengeCooldownMu.Lock()
|
||||
delete(c.challengeCooldowns, key)
|
||||
c.challengeCooldownMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
remaining := int(math.Ceil(entry.Until.Sub(now).Seconds()))
|
||||
if remaining < 1 {
|
||||
remaining = 1
|
||||
}
|
||||
message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining)
|
||||
if entry.ConsecutiveChallenges > 1 {
|
||||
message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges)
|
||||
}
|
||||
if entry.CFRay != "" {
|
||||
message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay)
|
||||
}
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Message: message,
|
||||
Headers: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
|
||||
if cooldownSeconds <= 0 {
|
||||
return
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
now := time.Now()
|
||||
cfRay := soraerror.ExtractCloudflareRayID(headers, body)
|
||||
|
||||
c.challengeCooldownMu.Lock()
|
||||
c.cleanupExpiredChallengeCooldownsLocked(now)
|
||||
|
||||
streak := 1
|
||||
existing, ok := c.challengeCooldowns[key]
|
||||
if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute {
|
||||
streak = existing.ConsecutiveChallenges + 1
|
||||
}
|
||||
effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak)
|
||||
until := now.Add(time.Duration(effectiveCooldown) * time.Second)
|
||||
if ok && existing.Until.After(until) {
|
||||
until = existing.Until
|
||||
if existing.ConsecutiveChallenges > streak {
|
||||
streak = existing.ConsecutiveChallenges
|
||||
}
|
||||
if cfRay == "" {
|
||||
cfRay = existing.CFRay
|
||||
}
|
||||
}
|
||||
c.challengeCooldowns[key] = soraChallengeCooldownEntry{
|
||||
Until: until,
|
||||
StatusCode: statusCode,
|
||||
CFRay: cfRay,
|
||||
ConsecutiveChallenges: streak,
|
||||
LastChallengeAt: now,
|
||||
}
|
||||
c.challengeCooldownMu.Unlock()
|
||||
|
||||
if c.debugEnabled() {
|
||||
remain := int(math.Ceil(until.Sub(now).Seconds()))
|
||||
if remain < 0 {
|
||||
remain = 0
|
||||
}
|
||||
c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay)
|
||||
}
|
||||
}
|
||||
|
||||
func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int {
|
||||
if baseSeconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
if streak < 1 {
|
||||
streak = 1
|
||||
}
|
||||
multiplier := streak
|
||||
if multiplier > 4 {
|
||||
multiplier = 4
|
||||
}
|
||||
cooldown := baseSeconds * multiplier
|
||||
if cooldown > 3600 {
|
||||
cooldown = 3600
|
||||
}
|
||||
return cooldown
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
c.challengeCooldownMu.Lock()
|
||||
_, existed := c.challengeCooldowns[key]
|
||||
if existed {
|
||||
delete(c.challengeCooldowns, key)
|
||||
}
|
||||
c.challengeCooldownMu.Unlock()
|
||||
if existed && c.debugEnabled() {
|
||||
c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string {
|
||||
if c == nil || !c.sidecarSessionReuseEnabled() {
|
||||
return ""
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return ""
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
now := time.Now()
|
||||
ttlSeconds := c.sidecarSessionTTLSeconds()
|
||||
|
||||
c.sidecarSessionMu.Lock()
|
||||
defer c.sidecarSessionMu.Unlock()
|
||||
c.cleanupExpiredSidecarSessionsLocked(now)
|
||||
if existing, exists := c.sidecarSessions[key]; exists {
|
||||
existing.LastUsedAt = now
|
||||
c.sidecarSessions[key] = existing
|
||||
return existing.SessionKey
|
||||
}
|
||||
|
||||
expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second)
|
||||
if ttlSeconds <= 0 {
|
||||
expiresAt = now.Add(365 * 24 * time.Hour)
|
||||
}
|
||||
newEntry := soraSidecarSessionEntry{
|
||||
SessionKey: "sora-" + uuid.NewString(),
|
||||
ExpiresAt: expiresAt,
|
||||
LastUsedAt: now,
|
||||
}
|
||||
c.sidecarSessions[key] = newEntry
|
||||
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds)
|
||||
}
|
||||
return newEntry.SessionKey
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) {
|
||||
if c == nil || len(c.challengeCooldowns) == 0 {
|
||||
return
|
||||
}
|
||||
for key, entry := range c.challengeCooldowns {
|
||||
if !entry.Until.After(now) {
|
||||
delete(c.challengeCooldowns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) {
|
||||
if c == nil || len(c.sidecarSessions) == 0 {
|
||||
return
|
||||
}
|
||||
for key, entry := range c.sidecarSessions {
|
||||
if !entry.ExpiresAt.After(now) {
|
||||
delete(c.sidecarSessions, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func soraAccountProxyKey(account *Account, proxyURL string) string {
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL))
|
||||
}
|
||||
|
||||
func normalizeSoraProxyKey(proxyURL string) string {
|
||||
raw := strings.TrimSpace(proxyURL)
|
||||
if raw == "" {
|
||||
return "direct"
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return strings.ToLower(raw)
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
port := strings.TrimSpace(parsed.Port())
|
||||
if host == "" {
|
||||
return strings.ToLower(raw)
|
||||
}
|
||||
if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") {
|
||||
port = ""
|
||||
}
|
||||
if port != "" {
|
||||
host = host + ":" + port
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = "proxy"
|
||||
}
|
||||
return scheme + "://" + host
|
||||
}
|
||||
808
backend/internal/service/sora_sdk_client.go
Normal file
808
backend/internal/service/sora_sdk_client.go
Normal file
@@ -0,0 +1,808 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/DouDOU-start/go-sora2api/sora"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// SoraSDKClient 基于 go-sora2api SDK 的 Sora 客户端实现。
|
||||
// 它实现了 SoraClient 接口,用 SDK 替代原有的自建 HTTP/PoW/TLS 指纹逻辑。
|
||||
type SoraSDKClient struct {
|
||||
cfg *config.Config
|
||||
httpUpstream HTTPUpstream
|
||||
tokenProvider *OpenAITokenProvider
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository
|
||||
|
||||
// 每个 proxyURL 对应一个 SDK 客户端实例
|
||||
sdkClients sync.Map // key: proxyURL (string), value: *sora.Client
|
||||
}
|
||||
|
||||
// NewSoraSDKClient 创建基于 SDK 的 Sora 客户端
|
||||
func NewSoraSDKClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraSDKClient {
|
||||
return &SoraSDKClient{
|
||||
cfg: cfg,
|
||||
httpUpstream: httpUpstream,
|
||||
tokenProvider: tokenProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// SetAccountRepositories 设置账号和 Sora 扩展仓库(用于 token 持久化)
|
||||
func (c *SoraSDKClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.accountRepo = accountRepo
|
||||
c.soraAccountRepo = soraAccountRepo
|
||||
}
|
||||
|
||||
// Enabled 判断是否启用 Sora
|
||||
func (c *SoraSDKClient) Enabled() bool {
|
||||
if c == nil || c.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
|
||||
}
|
||||
|
||||
// PreflightCheck 在创建任务前执行账号能力预检。
|
||||
// 当前仅对视频模型执行预检,用于提前识别额度耗尽或能力缺失。
|
||||
func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
|
||||
if modelCfg.Type != "video" {
|
||||
return nil
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
balance, err := sdkClient.GetCreditBalance(ctx, token)
|
||||
if err != nil {
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "当前账号未开通 Sora2 能力或无可用配额",
|
||||
}
|
||||
}
|
||||
if balance.RateLimitReached || balance.RemainingCount <= 0 {
|
||||
msg := "当前账号 Sora2 可用配额不足"
|
||||
if requestedModel != "" {
|
||||
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
|
||||
}
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("empty image data")
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if filename == "" {
|
||||
filename = "image.png"
|
||||
}
|
||||
mediaID, err := sdkClient.UploadImage(ctx, token, data, filename)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return mediaID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
var taskID string
|
||||
if strings.TrimSpace(req.MediaID) != "" {
|
||||
taskID, err = sdkClient.CreateImageTaskWithImage(ctx, token, sentinel, req.Prompt, req.Width, req.Height, req.MediaID)
|
||||
} else {
|
||||
taskID, err = sdkClient.CreateImageTask(ctx, token, sentinel, req.Prompt, req.Width, req.Height)
|
||||
}
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
|
||||
orientation := req.Orientation
|
||||
if orientation == "" {
|
||||
orientation = "landscape"
|
||||
}
|
||||
nFrames := req.Frames
|
||||
if nFrames <= 0 {
|
||||
nFrames = 450
|
||||
}
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = "sy_8"
|
||||
}
|
||||
size := req.Size
|
||||
if size == "" {
|
||||
size = "small"
|
||||
}
|
||||
|
||||
// Remix 模式
|
||||
if strings.TrimSpace(req.RemixTargetID) != "" {
|
||||
styleID := "" // SDK ExtractStyle 可从 prompt 中提取
|
||||
taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
// 普通视频(文生视频或图生视频)
|
||||
taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
|
||||
orientation := req.Orientation
|
||||
if orientation == "" {
|
||||
orientation = "landscape"
|
||||
}
|
||||
nFrames := req.Frames
|
||||
if nFrames <= 0 {
|
||||
nFrames = 450
|
||||
}
|
||||
|
||||
taskID, err := sdkClient.CreateStoryboardTask(ctx, token, sentinel, req.Prompt, orientation, nFrames, req.MediaID, "")
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("empty video data")
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cameoID, err := sdkClient.UploadCharacterVideo(ctx, token, data)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return cameoID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status, err := sdkClient.GetCameoStatus(ctx, token, cameoID)
|
||||
if err != nil {
|
||||
return nil, c.wrapSDKError(err, account)
|
||||
}
|
||||
return &SoraCameoStatus{
|
||||
Status: status.Status,
|
||||
DisplayNameHint: status.DisplayNameHint,
|
||||
UsernameHint: status.UsernameHint,
|
||||
ProfileAssetURL: status.ProfileAssetURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := sdkClient.DownloadCharacterImage(ctx, imageURL)
|
||||
if err != nil {
|
||||
return nil, c.wrapSDKError(err, account)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("empty character image")
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
assetPointer, err := sdkClient.UploadCharacterImage(ctx, token, data)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return assetPointer, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
characterID, err := sdkClient.FinalizeCharacter(ctx, token, req.CameoID, req.Username, req.DisplayName, req.ProfileAssetPointer)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return characterID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sdkClient.SetCharacterPublic(ctx, token, cameoID); err != nil {
|
||||
return c.wrapSDKError(err, account)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sdkClient.DeleteCharacter(ctx, token, characterID); err != nil {
|
||||
return c.wrapSDKError(err, account)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
postID, err := sdkClient.PublishVideo(ctx, token, sentinel, generationID)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return postID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) DeletePost(ctx context.Context, account *Account, postID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sdkClient.DeletePost(ctx, token, postID); err != nil {
|
||||
return c.wrapSDKError(err, account)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWatermarkFreeURLCustom 使用自定义第三方解析服务获取去水印链接。
|
||||
// SDK 不涉及此功能,保留自建实现。
|
||||
func (c *SoraSDKClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
|
||||
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
|
||||
if parseURL == "" {
|
||||
return "", errors.New("custom parse url is required")
|
||||
}
|
||||
if strings.TrimSpace(parseToken) == "" {
|
||||
return "", errors.New("custom parse token is required")
|
||||
}
|
||||
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
|
||||
payload := map[string]any{
|
||||
"url": shareURL,
|
||||
"token": strings.TrimSpace(parseToken),
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
accountID := int64(0)
|
||||
accountConcurrency := 0
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountConcurrency = account.Concurrency
|
||||
}
|
||||
var resp *http.Response
|
||||
if c.httpUpstream != nil {
|
||||
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
|
||||
}
|
||||
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
|
||||
if downloadLink == "" {
|
||||
return "", errors.New("custom parse response missing download_link")
|
||||
}
|
||||
return downloadLink, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(expansionLevel) == "" {
|
||||
expansionLevel = "medium"
|
||||
}
|
||||
if durationS <= 0 {
|
||||
durationS = 10
|
||||
}
|
||||
enhanced, err := sdkClient.EnhancePrompt(ctx, token, prompt, expansionLevel, durationS)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return enhanced, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := sdkClient.QueryImageTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second))
|
||||
if result.Err != nil {
|
||||
return &SoraImageTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: result.Err.Error(),
|
||||
}, nil
|
||||
}
|
||||
if result.Done && result.ImageURL != "" {
|
||||
return &SoraImageTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "succeeded",
|
||||
URLs: []string{result.ImageURL},
|
||||
}, nil
|
||||
}
|
||||
status := result.Progress.Status
|
||||
if status == "" {
|
||||
status = "processing"
|
||||
}
|
||||
return &SoraImageTaskStatus{
|
||||
ID: taskID,
|
||||
Status: status,
|
||||
ProgressPct: float64(result.Progress.Percent) / 100.0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先查询 pending 列表
|
||||
result := sdkClient.QueryVideoTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second), 0)
|
||||
if result.Err != nil {
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: result.Err.Error(),
|
||||
}, nil
|
||||
}
|
||||
if !result.Done {
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: result.Progress.Status,
|
||||
ProgressPct: result.Progress.Percent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 任务不在 pending 中,查询 drafts 获取下载链接
|
||||
downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") {
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: errMsg,
|
||||
}, nil
|
||||
}
|
||||
// 可能还在处理中
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "processing",
|
||||
}, nil
|
||||
}
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "completed",
|
||||
URLs: []string{downloadURL},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- 内部方法 ---
|
||||
|
||||
// getSDKClient 获取或创建指定代理的 SDK 客户端实例
|
||||
func (c *SoraSDKClient) getSDKClient(account *Account) (*sora.Client, error) {
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
if v, ok := c.sdkClients.Load(proxyURL); ok {
|
||||
if cli, ok2 := v.(*sora.Client); ok2 {
|
||||
return cli, nil
|
||||
}
|
||||
}
|
||||
client, err := sora.New(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 Sora SDK 客户端失败: %w", err)
|
||||
}
|
||||
actual, _ := c.sdkClients.LoadOrStore(proxyURL, client)
|
||||
if cli, ok := actual.(*sora.Client); ok {
|
||||
return cli, nil
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) resolveProxyURL(account *Account) string {
|
||||
if account == nil || account.ProxyID == nil || account.Proxy == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(account.Proxy.URL())
|
||||
}
|
||||
|
||||
// getAccessToken 获取账号的 access_token,支持多种 token 来源和自动刷新。
|
||||
// 此方法保留了原 SoraDirectClient 的 token 管理逻辑。
|
||||
func (c *SoraSDKClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
|
||||
// 优先尝试 OpenAI Token Provider
|
||||
allowProvider := c.allowOpenAITokenProvider(account)
|
||||
var providerErr error
|
||||
if allowProvider && c.tokenProvider != nil {
|
||||
token, err := c.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err == nil && strings.TrimSpace(token) != "" {
|
||||
c.debugLogf("token_selected account_id=%d source=openai_token_provider", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
providerErr = err
|
||||
if err != nil && c.debugEnabled() {
|
||||
c.debugLogf("token_provider_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试直接使用 credentials 中的 access_token
|
||||
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if token != "" {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
|
||||
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
|
||||
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
|
||||
return refreshed, nil
|
||||
}
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 尝试通过 session_token 或 refresh_token 恢复
|
||||
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
|
||||
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
|
||||
return recovered, nil
|
||||
}
|
||||
if providerErr != nil {
|
||||
return "", providerErr
|
||||
}
|
||||
return "", errors.New("access_token not found")
|
||||
}
|
||||
|
||||
// recoverAccessToken 通过 session_token 或 refresh_token 恢复 access_token
|
||||
func (c *SoraSDKClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
|
||||
// 先尝试 session_token
|
||||
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
|
||||
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
|
||||
if err == nil && strings.TrimSpace(accessToken) != "" {
|
||||
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
|
||||
return accessToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 再尝试 refresh_token
|
||||
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
||||
if refreshToken == "" {
|
||||
return "", errors.New("session_token/refresh_token not found")
|
||||
}
|
||||
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 尝试多个 client_id
|
||||
clientIDs := []string{
|
||||
strings.TrimSpace(account.GetCredential("client_id")),
|
||||
openaioauth.SoraClientID,
|
||||
openaioauth.ClientID,
|
||||
}
|
||||
tried := make(map[string]struct{}, len(clientIDs))
|
||||
var lastErr error
|
||||
|
||||
for _, clientID := range clientIDs {
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := tried[clientID]; ok {
|
||||
continue
|
||||
}
|
||||
tried[clientID] = struct{}{}
|
||||
|
||||
newAccess, newRefresh, refreshErr := sdkClient.RefreshAccessToken(ctx, refreshToken, clientID)
|
||||
if refreshErr != nil {
|
||||
lastErr = refreshErr
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(newAccess) == "" {
|
||||
lastErr = errors.New("refreshed access_token is empty")
|
||||
continue
|
||||
}
|
||||
c.applyRecoveredToken(ctx, account, newAccess, newRefresh, "", "")
|
||||
return newAccess, nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return "", lastErr
|
||||
}
|
||||
return "", errors.New("no available client_id for refresh_token exchange")
|
||||
}
|
||||
|
||||
// exchangeSessionToken 通过 session_token 换取 access_token
|
||||
func (c *SoraSDKClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://sora.chatgpt.com/api/auth/session", nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
accountID := int64(0)
|
||||
accountConcurrency := 0
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountConcurrency = account.Concurrency
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
if c.httpUpstream != nil {
|
||||
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", "", fmt.Errorf("session exchange failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("session exchange missing accessToken")
|
||||
}
|
||||
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
|
||||
return accessToken, expiresAt, nil
|
||||
}
|
||||
|
||||
// applyRecoveredToken 将恢复的 token 写入账号内存和数据库
|
||||
func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
if strings.TrimSpace(accessToken) != "" {
|
||||
account.Credentials["access_token"] = accessToken
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) != "" {
|
||||
account.Credentials["refresh_token"] = refreshToken
|
||||
}
|
||||
if strings.TrimSpace(expiresAt) != "" {
|
||||
account.Credentials["expires_at"] = expiresAt
|
||||
}
|
||||
if strings.TrimSpace(sessionToken) != "" {
|
||||
account.Credentials["session_token"] = sessionToken
|
||||
}
|
||||
|
||||
if c.accountRepo != nil {
|
||||
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() {
|
||||
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
|
||||
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
updates := make(map[string]any)
|
||||
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
|
||||
updates["access_token"] = accessToken
|
||||
updates["refresh_token"] = refreshToken
|
||||
}
|
||||
if strings.TrimSpace(sessionToken) != "" {
|
||||
updates["session_token"] = sessionToken
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
|
||||
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) allowOpenAITokenProvider(account *Account) bool {
|
||||
if c == nil || c.tokenProvider == nil {
|
||||
return false
|
||||
}
|
||||
if account != nil && account.Platform == PlatformSora {
|
||||
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// wrapSDKError 将 SDK 错误包装为 SoraUpstreamError
|
||||
func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
msg := err.Error()
|
||||
statusCode := http.StatusBadGateway
|
||||
if strings.Contains(msg, "HTTP 401") || strings.Contains(msg, "HTTP 403") {
|
||||
statusCode = http.StatusUnauthorized
|
||||
} else if strings.Contains(msg, "HTTP 429") {
|
||||
statusCode = http.StatusTooManyRequests
|
||||
} else if strings.Contains(msg, "HTTP 404") {
|
||||
statusCode = http.StatusNotFound
|
||||
}
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: statusCode,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) debugEnabled() bool {
|
||||
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) debugLogf(format string, args ...any) {
|
||||
if c.debugEnabled() {
|
||||
log.Printf("[SoraSDK] "+format, args...)
|
||||
}
|
||||
}
|
||||
@@ -206,14 +206,14 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||
return NewSoraMediaStorage(cfg)
|
||||
}
|
||||
|
||||
func ProvideSoraDirectClient(
|
||||
func ProvideSoraSDKClient(
|
||||
cfg *config.Config,
|
||||
httpUpstream HTTPUpstream,
|
||||
tokenProvider *OpenAITokenProvider,
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
) *SoraDirectClient {
|
||||
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
|
||||
) *SoraSDKClient {
|
||||
client := NewSoraSDKClient(cfg, httpUpstream, tokenProvider)
|
||||
client.SetAccountRepositories(accountRepo, soraAccountRepo)
|
||||
return client
|
||||
}
|
||||
@@ -306,8 +306,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewGatewayService,
|
||||
ProvideSoraMediaStorage,
|
||||
ProvideSoraMediaCleanupService,
|
||||
ProvideSoraDirectClient,
|
||||
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
||||
ProvideSoraSDKClient,
|
||||
wire.Bind(new(SoraClient), new(*SoraSDKClient)),
|
||||
NewSoraGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
NewOAuthService,
|
||||
|
||||
42
backend/migrations/058_add_sonnet46_to_model_mapping.sql
Normal file
42
backend/migrations/058_add_sonnet46_to_model_mapping.sql
Normal file
@@ -0,0 +1,42 @@
|
||||
-- Add claude-sonnet-4-6 to model_mapping for all Antigravity accounts
|
||||
--
|
||||
-- Background:
|
||||
-- Antigravity now supports claude-sonnet-4-6
|
||||
--
|
||||
-- Strategy:
|
||||
-- Directly overwrite the entire model_mapping with updated mappings
|
||||
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
|
||||
|
||||
UPDATE accounts
|
||||
SET credentials = jsonb_set(
|
||||
credentials,
|
||||
'{model_mapping}',
|
||||
'{
|
||||
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-6": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||
"gemini-3-pro-image": "gemini-3-pro-image",
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview"
|
||||
}'::jsonb
|
||||
)
|
||||
WHERE platform = 'antigravity'
|
||||
AND deleted_at IS NULL
|
||||
AND credentials->'model_mapping' IS NOT NULL;
|
||||
45
backend/migrations/059_add_gemini31_pro_to_model_mapping.sql
Normal file
45
backend/migrations/059_add_gemini31_pro_to_model_mapping.sql
Normal file
@@ -0,0 +1,45 @@
|
||||
-- Add gemini-3.1-pro-high, gemini-3.1-pro-low, gemini-3.1-pro-preview to model_mapping
|
||||
--
|
||||
-- Background:
|
||||
-- Antigravity now supports gemini-3.1-pro-high and gemini-3.1-pro-low
|
||||
--
|
||||
-- Strategy:
|
||||
-- Directly overwrite the entire model_mapping with updated mappings
|
||||
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
|
||||
|
||||
UPDATE accounts
|
||||
SET credentials = jsonb_set(
|
||||
credentials,
|
||||
'{model_mapping}',
|
||||
'{
|
||||
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-6": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||
"gemini-3-pro-image": "gemini-3-pro-image",
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview"
|
||||
}'::jsonb
|
||||
)
|
||||
WHERE platform = 'antigravity'
|
||||
AND deleted_at IS NULL
|
||||
AND credentials->'model_mapping' IS NOT NULL;
|
||||
@@ -0,0 +1,46 @@
|
||||
-- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview to model_mapping
|
||||
--
|
||||
-- Background:
|
||||
-- Antigravity now supports gemini-3.1-flash-image as the latest image generation model,
|
||||
-- replacing the previous gemini-3-pro-image.
|
||||
--
|
||||
-- Strategy:
|
||||
-- Directly overwrite the entire model_mapping with updated mappings
|
||||
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
|
||||
|
||||
UPDATE accounts
|
||||
SET credentials = jsonb_set(
|
||||
credentials,
|
||||
'{model_mapping}',
|
||||
'{
|
||||
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-6": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview"
|
||||
}'::jsonb
|
||||
)
|
||||
WHERE platform = 'antigravity'
|
||||
AND deleted_at IS NULL
|
||||
AND credentials->'model_mapping' IS NOT NULL;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -781,10 +781,10 @@ rate_limit:
|
||||
pricing:
|
||||
# URL to fetch model pricing data (default: LiteLLM)
|
||||
# 获取模型定价数据的 URL(默认:LiteLLM)
|
||||
remote_url: "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
remote_url: "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json"
|
||||
# Hash verification URL (optional)
|
||||
# 哈希校验 URL(可选)
|
||||
hash_url: ""
|
||||
hash_url: "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256"
|
||||
# Local data directory for caching
|
||||
# 本地数据缓存目录
|
||||
data_dir: "./data"
|
||||
|
||||
@@ -173,7 +173,6 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||
- PGDATA=/var/lib/postgresql/data
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
networks:
|
||||
- sub2api-network
|
||||
|
||||
@@ -15,7 +15,9 @@ import type {
|
||||
AccountUsageStatsResponse,
|
||||
TempUnschedulableStatus,
|
||||
AdminDataPayload,
|
||||
AdminDataImportResult
|
||||
AdminDataImportResult,
|
||||
CheckMixedChannelRequest,
|
||||
CheckMixedChannelResponse
|
||||
} from '@/types'
|
||||
|
||||
/**
|
||||
@@ -133,6 +135,16 @@ export async function update(id: number, updates: UpdateAccountRequest): Promise
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Check mixed-channel risk for account-group binding.
|
||||
*/
|
||||
export async function checkMixedChannelRisk(
|
||||
payload: CheckMixedChannelRequest
|
||||
): Promise<CheckMixedChannelResponse> {
|
||||
const { data } = await apiClient.post<CheckMixedChannelResponse>('/admin/accounts/check-mixed-channel', payload)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete account
|
||||
* @param id - Account ID
|
||||
@@ -535,6 +547,7 @@ export const accountsAPI = {
|
||||
getById,
|
||||
create,
|
||||
update,
|
||||
checkMixedChannelRisk,
|
||||
delete: deleteAccount,
|
||||
toggleStatus,
|
||||
testAccount,
|
||||
|
||||
@@ -77,13 +77,23 @@
|
||||
</div>
|
||||
|
||||
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
|
||||
<template v-if="activeModelRateLimits.length > 0">
|
||||
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
|
||||
<div
|
||||
v-if="activeModelRateLimits.length > 0"
|
||||
:class="[
|
||||
activeModelRateLimits.length <= 4
|
||||
? 'flex flex-col gap-1'
|
||||
: activeModelRateLimits.length <= 8
|
||||
? 'columns-2 gap-x-2'
|
||||
: 'columns-3 gap-x-2'
|
||||
]"
|
||||
>
|
||||
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative mb-1 break-inside-avoid">
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded bg-purple-100 px-1.5 py-0.5 text-xs font-medium text-purple-700 dark:bg-purple-900/30 dark:text-purple-400"
|
||||
>
|
||||
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
|
||||
{{ formatScopeName(item.model) }}
|
||||
<span class="text-[10px] opacity-70">{{ formatModelResetTime(item.reset_at) }}</span>
|
||||
</span>
|
||||
<!-- Tooltip -->
|
||||
<div
|
||||
@@ -95,7 +105,7 @@
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- Overload Indicator (529) -->
|
||||
<div v-if="isOverloaded" class="group relative">
|
||||
@@ -154,17 +164,52 @@ const activeModelRateLimits = computed(() => {
|
||||
})
|
||||
|
||||
const formatScopeName = (scope: string): string => {
|
||||
const names: Record<string, string> = {
|
||||
const aliases: Record<string, string> = {
|
||||
// Claude 系列
|
||||
'claude-opus-4-6': 'COpus46',
|
||||
'claude-opus-4-6-thinking': 'COpus46T',
|
||||
'claude-sonnet-4-6': 'CSon46',
|
||||
'claude-sonnet-4-5': 'CSon45',
|
||||
'claude-sonnet-4-5-thinking': 'CSon45T',
|
||||
// Gemini 2.5 系列
|
||||
'gemini-2.5-flash': 'G25F',
|
||||
'gemini-2.5-flash-lite': 'G25FL',
|
||||
'gemini-2.5-flash-thinking': 'G25FT',
|
||||
'gemini-2.5-pro': 'G25P',
|
||||
// Gemini 3 系列
|
||||
'gemini-3-flash': 'G3F',
|
||||
'gemini-3.1-pro-high': 'G3PH',
|
||||
'gemini-3.1-pro-low': 'G3PL',
|
||||
'gemini-3-pro-image': 'G3PI',
|
||||
'gemini-3.1-flash-image': 'GImage',
|
||||
// 其他
|
||||
'gpt-oss-120b-medium': 'GPT120',
|
||||
'tab_flash_lite_preview': 'TabFL',
|
||||
// 旧版 scope 别名(兼容)
|
||||
claude: 'Claude',
|
||||
claude_sonnet: 'Claude Sonnet',
|
||||
claude_opus: 'Claude Opus',
|
||||
claude_haiku: 'Claude Haiku',
|
||||
claude_sonnet: 'CSon',
|
||||
claude_opus: 'COpus',
|
||||
claude_haiku: 'CHaiku',
|
||||
gemini_text: 'Gemini',
|
||||
gemini_image: 'Image',
|
||||
gemini_flash: 'Gemini Flash',
|
||||
gemini_pro: 'Gemini Pro'
|
||||
gemini_image: 'GImg',
|
||||
gemini_flash: 'GFlash',
|
||||
gemini_pro: 'GPro',
|
||||
}
|
||||
return names[scope] || scope
|
||||
return aliases[scope] || scope
|
||||
}
|
||||
|
||||
const formatModelResetTime = (resetAt: string): string => {
|
||||
const date = new Date(resetAt)
|
||||
const now = new Date()
|
||||
const diffMs = date.getTime() - now.getTime()
|
||||
if (diffMs <= 0) return ''
|
||||
const totalSecs = Math.floor(diffMs / 1000)
|
||||
const h = Math.floor(totalSecs / 3600)
|
||||
const m = Math.floor((totalSecs % 3600) / 60)
|
||||
const s = totalSecs % 60
|
||||
if (h > 0) return `${h}h${m}m`
|
||||
if (m > 0) return `${m}m${s}s`
|
||||
return `${s}s`
|
||||
}
|
||||
|
||||
// Computed: is overloaded (529)
|
||||
|
||||
@@ -172,12 +172,12 @@
|
||||
color="purple"
|
||||
/>
|
||||
|
||||
<!-- Claude 4.5 -->
|
||||
<!-- Claude -->
|
||||
<UsageProgressBar
|
||||
v-if="antigravityClaude45UsageFromAPI !== null"
|
||||
:label="t('admin.accounts.usageWindow.claude45')"
|
||||
:utilization="antigravityClaude45UsageFromAPI.utilization"
|
||||
:resets-at="antigravityClaude45UsageFromAPI.resetTime"
|
||||
v-if="antigravityClaudeUsageFromAPI !== null"
|
||||
:label="t('admin.accounts.usageWindow.claude')"
|
||||
:utilization="antigravityClaudeUsageFromAPI.utilization"
|
||||
:resets-at="antigravityClaudeUsageFromAPI.resetTime"
|
||||
color="amber"
|
||||
/>
|
||||
</div>
|
||||
@@ -397,12 +397,15 @@ const antigravity3ProUsageFromAPI = computed(() =>
|
||||
// Gemini 3 Flash from API
|
||||
const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-flash']))
|
||||
|
||||
// Gemini 3 Image from API
|
||||
const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-pro-image']))
|
||||
// Gemini Image from API
|
||||
const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3.1-flash-image']))
|
||||
|
||||
// Claude 4.5 from API
|
||||
const antigravityClaude45UsageFromAPI = computed(() =>
|
||||
getAntigravityUsageFromAPI(['claude-sonnet-4-5', 'claude-opus-4-5-thinking'])
|
||||
// Claude from API (all Claude model variants)
|
||||
const antigravityClaudeUsageFromAPI = computed(() =>
|
||||
getAntigravityUsageFromAPI([
|
||||
'claude-sonnet-4-5', 'claude-opus-4-5-thinking',
|
||||
'claude-sonnet-4-6', 'claude-opus-4-6', 'claude-opus-4-6-thinking',
|
||||
])
|
||||
)
|
||||
|
||||
// Antigravity 账户类型(从 load_code_assist 响应中提取)
|
||||
|
||||
@@ -21,6 +21,16 @@
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mixed platform warning -->
|
||||
<div v-if="isMixedPlatform" class="rounded-lg bg-amber-50 p-4 dark:bg-amber-900/20">
|
||||
<p class="text-sm text-amber-700 dark:text-amber-400">
|
||||
<svg class="mr-1.5 inline h-5 w-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" />
|
||||
</svg>
|
||||
{{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Base URL (API Key only) -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
@@ -157,7 +167,7 @@
|
||||
<!-- Model Checkbox List -->
|
||||
<div class="mb-3 grid grid-cols-2 gap-2">
|
||||
<label
|
||||
v-for="model in allModels"
|
||||
v-for="model in filteredModels"
|
||||
:key="model.value"
|
||||
class="flex cursor-pointer items-center rounded-lg border p-3 transition-all hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700"
|
||||
:class="
|
||||
@@ -209,7 +219,7 @@
|
||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in modelMappings"
|
||||
:key="getModelMappingKey(mapping)"
|
||||
:key="index"
|
||||
class="flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
@@ -278,7 +288,7 @@
|
||||
<!-- Quick Add Buttons -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in presetMappings"
|
||||
v-for="preset in filteredPresets"
|
||||
:key="preset.label"
|
||||
type="button"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
@@ -648,18 +658,19 @@ import { ref, watch, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Proxy, AdminGroup } from '@/types'
|
||||
import type { Proxy as ProxyConfig, AdminGroup, AccountPlatform } from '@/types'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import { buildModelMappingObject as buildModelMappingPayload } from '@/composables/useModelWhitelist'
|
||||
|
||||
interface Props {
|
||||
show: boolean
|
||||
accountIds: number[]
|
||||
proxies: Proxy[]
|
||||
selectedPlatforms: AccountPlatform[]
|
||||
proxies: ProxyConfig[]
|
||||
groups: AdminGroup[]
|
||||
}
|
||||
|
||||
@@ -672,6 +683,31 @@ const emit = defineEmits<{
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
// Platform awareness
|
||||
const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1)
|
||||
|
||||
const platformModelPrefix: Record<string, string[]> = {
|
||||
anthropic: ['claude-'],
|
||||
antigravity: ['claude-'],
|
||||
openai: ['gpt-'],
|
||||
gemini: ['gemini-'],
|
||||
sora: []
|
||||
}
|
||||
|
||||
const filteredModels = computed(() => {
|
||||
if (props.selectedPlatforms.length === 0) return allModels
|
||||
const prefixes = [...new Set(props.selectedPlatforms.flatMap(p => platformModelPrefix[p] || []))]
|
||||
if (prefixes.length === 0) return allModels
|
||||
return allModels.filter(m => prefixes.some(prefix => m.value.startsWith(prefix)))
|
||||
})
|
||||
|
||||
const filteredPresets = computed(() => {
|
||||
if (props.selectedPlatforms.length === 0) return presetMappings
|
||||
const prefixes = [...new Set(props.selectedPlatforms.flatMap(p => platformModelPrefix[p] || []))]
|
||||
if (prefixes.length === 0) return presetMappings
|
||||
return presetMappings.filter(m => prefixes.some(prefix => m.from.startsWith(prefix)))
|
||||
})
|
||||
|
||||
// Model mapping type
|
||||
interface ModelMapping {
|
||||
from: string
|
||||
@@ -696,7 +732,6 @@ const baseUrl = ref('')
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('bulk-model-mapping')
|
||||
const selectedErrorCodes = ref<number[]>([])
|
||||
const customErrorCodeInput = ref<number | null>(null)
|
||||
const interceptWarmupRequests = ref(false)
|
||||
@@ -707,7 +742,7 @@ const rateMultiplier = ref(1)
|
||||
const status = ref<'active' | 'inactive'>('active')
|
||||
const groupIds = ref<number[]>([])
|
||||
|
||||
// All models list (combined Anthropic + OpenAI)
|
||||
// All models list (combined Anthropic + OpenAI + Gemini)
|
||||
const allModels = [
|
||||
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
|
||||
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
|
||||
@@ -719,6 +754,7 @@ const allModels = [
|
||||
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
|
||||
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
|
||||
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
|
||||
{ value: 'gpt-5.3-codex', label: 'GPT-5.3 Codex' },
|
||||
{ value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' },
|
||||
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
||||
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
||||
@@ -726,10 +762,15 @@ const allModels = [
|
||||
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
|
||||
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
||||
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' }
|
||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
|
||||
{ value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' },
|
||||
{ value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' },
|
||||
{ value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' },
|
||||
{ value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' },
|
||||
{ value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' }
|
||||
]
|
||||
|
||||
// Preset mappings (combined Anthropic + OpenAI)
|
||||
// Preset mappings (combined Anthropic + OpenAI + Gemini)
|
||||
const presetMappings = [
|
||||
{
|
||||
label: 'Sonnet 4',
|
||||
@@ -754,7 +795,14 @@ const presetMappings = [
|
||||
{
|
||||
label: 'Opus 4.6',
|
||||
from: 'claude-opus-4-6',
|
||||
to: 'claude-opus-4-6',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus 4.6-thinking',
|
||||
from: 'claude-opus-4-6-thinking',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
@@ -765,6 +813,31 @@ const presetMappings = [
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet4→4.6',
|
||||
from: 'claude-sonnet-4-20250514',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet4.5→4.6',
|
||||
from: 'claude-sonnet-4-5-20250929',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet3.5→4.6',
|
||||
from: 'claude-3-5-sonnet-20241022',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus4.5→4.6',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus->Sonnet',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
@@ -772,10 +845,22 @@ const presetMappings = [
|
||||
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.3 Codex Spark',
|
||||
label: 'GPT-5.3 Codex',
|
||||
from: 'gpt-5.3-codex',
|
||||
to: 'gpt-5.3-codex',
|
||||
color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.3 Spark',
|
||||
from: 'gpt-5.3-codex-spark',
|
||||
to: 'gpt-5.3-codex-spark',
|
||||
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
|
||||
color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
},
|
||||
{
|
||||
label: '5.2→5.3',
|
||||
from: 'gpt-5.2-codex',
|
||||
to: 'gpt-5.3-codex',
|
||||
color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.2',
|
||||
@@ -794,6 +879,36 @@ const presetMappings = [
|
||||
from: 'gpt-5.1-codex-max',
|
||||
to: 'gpt-5.1-codex',
|
||||
color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-Preview→3.1-Pro-High',
|
||||
from: 'gemini-3-pro-preview',
|
||||
to: 'gemini-3.1-pro-high',
|
||||
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-High→3.1-Pro-High',
|
||||
from: 'gemini-3-pro-high',
|
||||
to: 'gemini-3.1-pro-high',
|
||||
color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-Low→3.1-Pro-Low',
|
||||
from: 'gemini-3-pro-low',
|
||||
to: 'gemini-3.1-pro-low',
|
||||
color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400'
|
||||
},
|
||||
{
|
||||
label: '3-Flash透传',
|
||||
from: 'gemini-3-flash',
|
||||
to: 'gemini-3-flash',
|
||||
color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400'
|
||||
},
|
||||
{
|
||||
label: '2.5-Flash-Lite透传',
|
||||
from: 'gemini-2.5-flash-lite',
|
||||
to: 'gemini-2.5-flash-lite',
|
||||
color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400'
|
||||
}
|
||||
]
|
||||
|
||||
@@ -883,23 +998,11 @@ const removeErrorCode = (code: number) => {
|
||||
}
|
||||
|
||||
const buildModelMappingObject = (): Record<string, string> | null => {
|
||||
const mapping: Record<string, string> = {}
|
||||
|
||||
if (modelRestrictionMode.value === 'whitelist') {
|
||||
for (const model of allowedModels.value) {
|
||||
mapping[model] = model
|
||||
}
|
||||
} else {
|
||||
for (const m of modelMappings.value) {
|
||||
const from = m.from.trim()
|
||||
const to = m.to.trim()
|
||||
if (from && to) {
|
||||
mapping[from] = to
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Object.keys(mapping).length > 0 ? mapping : null
|
||||
return buildModelMappingPayload(
|
||||
modelRestrictionMode.value,
|
||||
allowedModels.value,
|
||||
modelMappings.value
|
||||
)
|
||||
}
|
||||
|
||||
const buildUpdatePayload = (): Record<string, unknown> | null => {
|
||||
|
||||
@@ -916,8 +916,8 @@
|
||||
<p class="input-hint">{{ t('admin.accounts.gemini.tier.aiStudioHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section (不适用于 Gemini,Antigravity 已在上层条件排除) -->
|
||||
<div v-if="form.platform !== 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<!-- Model Restriction Section (Antigravity 已在上层条件排除) -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<div
|
||||
@@ -1200,34 +1200,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Gemini 模型说明 -->
|
||||
<div v-if="form.platform === 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="rounded-lg bg-blue-50 p-4 dark:bg-blue-900/20">
|
||||
<div class="flex items-start gap-3">
|
||||
<svg
|
||||
class="h-5 w-5 flex-shrink-0 text-blue-600 dark:text-blue-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
<div>
|
||||
<p class="text-sm font-medium text-blue-800 dark:text-blue-300">
|
||||
{{ t('admin.accounts.gemini.modelPassthrough') }}
|
||||
</p>
|
||||
<p class="mt-1 text-xs text-blue-700 dark:text-blue-400">
|
||||
{{ t('admin.accounts.gemini.modelPassthroughDesc') }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Temp Unschedulable Rules -->
|
||||
@@ -1378,9 +1350,9 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Intercept Warmup Requests (Anthropic only) -->
|
||||
<!-- Intercept Warmup Requests (Anthropic/Antigravity) -->
|
||||
<div
|
||||
v-if="form.platform === 'anthropic'"
|
||||
v-if="form.platform === 'anthropic' || form.platform === 'antigravity'"
|
||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||
>
|
||||
<div class="flex items-center justify-between">
|
||||
@@ -1844,12 +1816,14 @@
|
||||
:show-cookie-option="form.platform === 'anthropic'"
|
||||
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
|
||||
:show-session-token-option="form.platform === 'sora'"
|
||||
:show-access-token-option="form.platform === 'sora'"
|
||||
:platform="form.platform"
|
||||
:show-project-id="geminiOAuthType === 'code_assist'"
|
||||
@generate-url="handleGenerateUrl"
|
||||
@cookie-auth="handleCookieAuth"
|
||||
@validate-refresh-token="handleValidateRefreshToken"
|
||||
@validate-session-token="handleValidateSessionToken"
|
||||
@import-access-token="handleImportAccessToken"
|
||||
/>
|
||||
|
||||
</div>
|
||||
@@ -2157,7 +2131,7 @@
|
||||
<ConfirmDialog
|
||||
:show="showMixedChannelWarning"
|
||||
:title="t('admin.accounts.mixedChannelWarningTitle')"
|
||||
:message="mixedChannelWarningDetails ? t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails) : ''"
|
||||
:message="mixedChannelWarningMessageText"
|
||||
:confirm-text="t('common.confirm')"
|
||||
:cancel-text="t('common.cancel')"
|
||||
:danger="true"
|
||||
@@ -2189,13 +2163,21 @@ import {
|
||||
import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
|
||||
import { useGeminiOAuth } from '@/composables/useGeminiOAuth'
|
||||
import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth'
|
||||
import type { Proxy, AdminGroup, AccountPlatform, AccountType } from '@/types'
|
||||
import type {
|
||||
Proxy,
|
||||
AdminGroup,
|
||||
AccountPlatform,
|
||||
AccountType,
|
||||
CheckMixedChannelResponse,
|
||||
CreateAccountRequest
|
||||
} from '@/types'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
||||
@@ -2337,10 +2319,13 @@ const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>
|
||||
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
|
||||
const geminiAIStudioOAuthEnabled = ref(false)
|
||||
|
||||
// Mixed channel warning dialog state
|
||||
const showMixedChannelWarning = ref(false)
|
||||
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null)
|
||||
const pendingCreatePayload = ref<any>(null)
|
||||
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
|
||||
null
|
||||
)
|
||||
const mixedChannelWarningRawMessage = ref('')
|
||||
const mixedChannelWarningAction = ref<(() => Promise<void>) | null>(null)
|
||||
const antigravityMixedChannelConfirmed = ref(false)
|
||||
const showAdvancedOAuth = ref(false)
|
||||
const showGeminiHelpDialog = ref(false)
|
||||
|
||||
@@ -2378,6 +2363,13 @@ const isOpenAIModelRestrictionDisabled = computed(() =>
|
||||
form.platform === 'openai' && openaiPassthroughEnabled.value
|
||||
)
|
||||
|
||||
const mixedChannelWarningMessageText = computed(() => {
|
||||
if (mixedChannelWarningDetails.value) {
|
||||
return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value)
|
||||
}
|
||||
return mixedChannelWarningRawMessage.value
|
||||
})
|
||||
|
||||
const geminiQuotaDocs = {
|
||||
codeAssist: 'https://developers.google.com/gemini-code-assist/resources/quotas',
|
||||
aiStudio: 'https://ai.google.dev/pricing',
|
||||
@@ -2544,8 +2536,8 @@ watch(
|
||||
antigravityModelMappings.value = []
|
||||
antigravityModelRestrictionMode.value = 'mapping'
|
||||
}
|
||||
// Reset Anthropic-specific settings when switching to other platforms
|
||||
if (newPlatform !== 'anthropic') {
|
||||
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
|
||||
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
|
||||
interceptWarmupRequests.value = false
|
||||
}
|
||||
if (newPlatform === 'sora') {
|
||||
@@ -2794,6 +2786,105 @@ const splitTempUnschedKeywords = (value: string) => {
|
||||
.filter((item) => item.length > 0)
|
||||
}
|
||||
|
||||
const needsMixedChannelCheck = (platform: AccountPlatform) => platform === 'antigravity' || platform === 'anthropic'
|
||||
|
||||
const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => {
|
||||
const details = resp?.details
|
||||
if (!details) {
|
||||
return null
|
||||
}
|
||||
return {
|
||||
groupName: details.group_name || 'Unknown',
|
||||
currentPlatform: details.current_platform || 'Unknown',
|
||||
otherPlatform: details.other_platform || 'Unknown'
|
||||
}
|
||||
}
|
||||
|
||||
const clearMixedChannelDialog = () => {
|
||||
showMixedChannelWarning.value = false
|
||||
mixedChannelWarningDetails.value = null
|
||||
mixedChannelWarningRawMessage.value = ''
|
||||
mixedChannelWarningAction.value = null
|
||||
}
|
||||
|
||||
const openMixedChannelDialog = (opts: {
|
||||
response?: CheckMixedChannelResponse
|
||||
message?: string
|
||||
onConfirm: () => Promise<void>
|
||||
}) => {
|
||||
mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response)
|
||||
mixedChannelWarningRawMessage.value =
|
||||
opts.message || opts.response?.message || t('admin.accounts.failedToCreate')
|
||||
mixedChannelWarningAction.value = opts.onConfirm
|
||||
showMixedChannelWarning.value = true
|
||||
}
|
||||
|
||||
const withAntigravityConfirmFlag = (payload: CreateAccountRequest): CreateAccountRequest => {
|
||||
if (needsMixedChannelCheck(payload.platform) && antigravityMixedChannelConfirmed.value) {
|
||||
return {
|
||||
...payload,
|
||||
confirm_mixed_channel_risk: true
|
||||
}
|
||||
}
|
||||
const cloned = { ...payload }
|
||||
delete cloned.confirm_mixed_channel_risk
|
||||
return cloned
|
||||
}
|
||||
|
||||
const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise<void>): Promise<boolean> => {
|
||||
if (!needsMixedChannelCheck(form.platform)) {
|
||||
return true
|
||||
}
|
||||
if (antigravityMixedChannelConfirmed.value) {
|
||||
return true
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await adminAPI.accounts.checkMixedChannelRisk({
|
||||
platform: form.platform,
|
||||
group_ids: form.group_ids
|
||||
})
|
||||
if (!result.has_risk) {
|
||||
return true
|
||||
}
|
||||
openMixedChannelDialog({
|
||||
response: result,
|
||||
onConfirm: async () => {
|
||||
antigravityMixedChannelConfirmed.value = true
|
||||
await onConfirm()
|
||||
}
|
||||
})
|
||||
return false
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const submitCreateAccount = async (payload: CreateAccountRequest) => {
|
||||
submitting.value = true
|
||||
try {
|
||||
await adminAPI.accounts.create(withAntigravityConfirmFlag(payload))
|
||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||
emit('created')
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck(form.platform)) {
|
||||
openMixedChannelDialog({
|
||||
message: error.response?.data?.message,
|
||||
onConfirm: async () => {
|
||||
antigravityMixedChannelConfirmed.value = true
|
||||
await submitCreateAccount(payload)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Methods
|
||||
const resetForm = () => {
|
||||
step.value = 1
|
||||
@@ -2855,9 +2946,13 @@ const resetForm = () => {
|
||||
geminiOAuth.resetState()
|
||||
antigravityOAuth.resetState()
|
||||
oauthFlowRef.value?.reset()
|
||||
antigravityMixedChannelConfirmed.value = false
|
||||
clearMixedChannelDialog()
|
||||
}
|
||||
|
||||
const handleClose = () => {
|
||||
antigravityMixedChannelConfirmed.value = false
|
||||
clearMixedChannelDialog()
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -2916,56 +3011,34 @@ const buildSoraExtra = (
|
||||
}
|
||||
|
||||
// Helper function to create account with mixed channel warning handling
|
||||
const doCreateAccount = async (payload: any) => {
|
||||
const doCreateAccount = async (payload: CreateAccountRequest) => {
|
||||
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
|
||||
await submitCreateAccount(payload)
|
||||
})
|
||||
if (!canContinue) {
|
||||
return
|
||||
}
|
||||
await submitCreateAccount(payload)
|
||||
}
|
||||
|
||||
// Handle mixed channel warning confirmation
|
||||
const handleMixedChannelConfirm = async () => {
|
||||
const action = mixedChannelWarningAction.value
|
||||
if (!action) {
|
||||
clearMixedChannelDialog()
|
||||
return
|
||||
}
|
||||
clearMixedChannelDialog()
|
||||
submitting.value = true
|
||||
try {
|
||||
await adminAPI.accounts.create(payload)
|
||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||
emit('created')
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
// Handle 409 mixed_channel_warning - show confirmation dialog
|
||||
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') {
|
||||
const details = error.response.data.details || {}
|
||||
mixedChannelWarningDetails.value = {
|
||||
groupName: details.group_name || 'Unknown',
|
||||
currentPlatform: details.current_platform || 'Unknown',
|
||||
otherPlatform: details.other_platform || 'Unknown'
|
||||
}
|
||||
pendingCreatePayload.value = payload
|
||||
showMixedChannelWarning.value = true
|
||||
} else {
|
||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
||||
}
|
||||
await action()
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Handle mixed channel warning confirmation
|
||||
const handleMixedChannelConfirm = async () => {
|
||||
showMixedChannelWarning.value = false
|
||||
if (pendingCreatePayload.value) {
|
||||
pendingCreatePayload.value.confirm_mixed_channel_risk = true
|
||||
submitting.value = true
|
||||
try {
|
||||
await adminAPI.accounts.create(pendingCreatePayload.value)
|
||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||
emit('created')
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
||||
} finally {
|
||||
submitting.value = false
|
||||
pendingCreatePayload.value = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const handleMixedChannelCancel = () => {
|
||||
showMixedChannelWarning.value = false
|
||||
pendingCreatePayload.value = null
|
||||
mixedChannelWarningDetails.value = null
|
||||
clearMixedChannelDialog()
|
||||
}
|
||||
|
||||
const handleSubmit = async () => {
|
||||
@@ -2975,6 +3048,12 @@ const handleSubmit = async () => {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
|
||||
step.value = 2
|
||||
})
|
||||
if (!canContinue) {
|
||||
return
|
||||
}
|
||||
step.value = 2
|
||||
return
|
||||
}
|
||||
@@ -3010,15 +3089,10 @@ const handleSubmit = async () => {
|
||||
credentials.model_mapping = antigravityModelMapping
|
||||
}
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
|
||||
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
|
||||
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
|
||||
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3059,10 +3133,7 @@ const handleSubmit = async () => {
|
||||
credentials.custom_error_codes = [...selectedErrorCodes.value]
|
||||
}
|
||||
|
||||
// Add intercept warmup requests setting
|
||||
if (interceptWarmupRequests.value) {
|
||||
credentials.intercept_warmup_requests = true
|
||||
}
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
if (!applyTempUnschedConfig(credentials)) {
|
||||
return
|
||||
}
|
||||
@@ -3119,6 +3190,83 @@ const handleValidateSessionToken = (sessionToken: string) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Sora 手动 AT 批量导入
|
||||
const handleImportAccessToken = async (accessTokenInput: string) => {
|
||||
const oauthClient = activeOpenAIOAuth.value
|
||||
if (!accessTokenInput.trim()) return
|
||||
|
||||
const accessTokens = accessTokenInput
|
||||
.split('\n')
|
||||
.map((at) => at.trim())
|
||||
.filter((at) => at)
|
||||
|
||||
if (accessTokens.length === 0) {
|
||||
oauthClient.error.value = 'Please enter at least one Access Token'
|
||||
return
|
||||
}
|
||||
|
||||
oauthClient.loading.value = true
|
||||
oauthClient.error.value = ''
|
||||
|
||||
let successCount = 0
|
||||
let failedCount = 0
|
||||
const errors: string[] = []
|
||||
|
||||
try {
|
||||
for (let i = 0; i < accessTokens.length; i++) {
|
||||
try {
|
||||
const credentials: Record<string, unknown> = {
|
||||
access_token: accessTokens[i],
|
||||
}
|
||||
const soraExtra = buildSoraExtra()
|
||||
|
||||
const accountName = accessTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||
await adminAPI.accounts.create({
|
||||
name: accountName,
|
||||
notes: form.notes,
|
||||
platform: 'sora',
|
||||
type: 'oauth',
|
||||
credentials,
|
||||
extra: soraExtra,
|
||||
proxy_id: form.proxy_id,
|
||||
concurrency: form.concurrency,
|
||||
priority: form.priority,
|
||||
rate_multiplier: form.rate_multiplier,
|
||||
group_ids: form.group_ids,
|
||||
expires_at: form.expires_at,
|
||||
auto_pause_on_expired: autoPauseOnExpired.value
|
||||
})
|
||||
successCount++
|
||||
} catch (error: any) {
|
||||
failedCount++
|
||||
const errMsg = error.response?.data?.detail || error.message || 'Unknown error'
|
||||
errors.push(`#${i + 1}: ${errMsg}`)
|
||||
}
|
||||
}
|
||||
|
||||
if (successCount > 0 && failedCount === 0) {
|
||||
appStore.showSuccess(
|
||||
accessTokens.length > 1
|
||||
? t('admin.accounts.oauth.batchSuccess', { count: successCount })
|
||||
: t('admin.accounts.accountCreated')
|
||||
)
|
||||
emit('created')
|
||||
handleClose()
|
||||
} else if (successCount > 0 && failedCount > 0) {
|
||||
appStore.showWarning(
|
||||
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
|
||||
)
|
||||
oauthClient.error.value = errors.join('\n')
|
||||
emit('created')
|
||||
} else {
|
||||
oauthClient.error.value = errors.join('\n')
|
||||
appStore.showError(t('admin.accounts.oauth.batchFailed'))
|
||||
}
|
||||
} finally {
|
||||
oauthClient.loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const formatDateTimeLocal = formatDateTimeLocalInput
|
||||
const parseDateTimeLocal = parseDateTimeLocalInput
|
||||
|
||||
@@ -3132,7 +3280,7 @@ const createAccountAndFinish = async (
|
||||
if (!applyTempUnschedConfig(credentials)) {
|
||||
return
|
||||
}
|
||||
await adminAPI.accounts.create({
|
||||
await doCreateAccount({
|
||||
name: form.name,
|
||||
notes: form.notes,
|
||||
platform,
|
||||
@@ -3147,9 +3295,6 @@ const createAccountAndFinish = async (
|
||||
expires_at: form.expires_at,
|
||||
auto_pause_on_expired: autoPauseOnExpired.value
|
||||
})
|
||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||
emit('created')
|
||||
handleClose()
|
||||
}
|
||||
|
||||
// OpenAI OAuth 授权码兑换
|
||||
@@ -3497,7 +3642,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
|
||||
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||
|
||||
// Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials
|
||||
await adminAPI.accounts.create({
|
||||
const createPayload = withAntigravityConfirmFlag({
|
||||
name: accountName,
|
||||
notes: form.notes,
|
||||
platform: 'antigravity',
|
||||
@@ -3512,6 +3657,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
|
||||
expires_at: form.expires_at,
|
||||
auto_pause_on_expired: autoPauseOnExpired.value
|
||||
})
|
||||
await adminAPI.accounts.create(createPayload)
|
||||
successCount++
|
||||
} catch (error: any) {
|
||||
failedCount++
|
||||
@@ -3606,6 +3752,7 @@ const handleAntigravityExchange = async (authCode: string) => {
|
||||
if (!tokenInfo) return
|
||||
|
||||
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
// Antigravity 只使用映射模式
|
||||
const antigravityModelMapping = buildModelMappingObject(
|
||||
'mapping',
|
||||
@@ -3677,10 +3824,8 @@ const handleAnthropicExchange = async (authCode: string) => {
|
||||
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
|
||||
}
|
||||
|
||||
const credentials = {
|
||||
...tokenInfo,
|
||||
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
|
||||
}
|
||||
const credentials: Record<string, unknown> = { ...tokenInfo }
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra)
|
||||
} catch (error: any) {
|
||||
oauth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
@@ -3779,11 +3924,8 @@ const handleCookieAuth = async (sessionKey: string) => {
|
||||
|
||||
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||
|
||||
// Merge interceptWarmupRequests into credentials
|
||||
const credentials: Record<string, unknown> = {
|
||||
...tokenInfo,
|
||||
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
|
||||
}
|
||||
const credentials: Record<string, unknown> = { ...tokenInfo }
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
if (tempUnschedEnabled.value) {
|
||||
credentials.temp_unschedulable_enabled = true
|
||||
credentials.temp_unschedulable_rules = tempUnschedPayload
|
||||
|
||||
@@ -65,8 +65,8 @@
|
||||
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section (不适用于 Gemini 和 Antigravity) -->
|
||||
<div v-if="account.platform !== 'gemini' && account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<!-- Model Restriction Section (不适用于 Antigravity) -->
|
||||
<div v-if="account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<div
|
||||
@@ -349,34 +349,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Gemini 模型说明 -->
|
||||
<div v-if="account.platform === 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<div class="rounded-lg bg-blue-50 p-4 dark:bg-blue-900/20">
|
||||
<div class="flex items-start gap-3">
|
||||
<svg
|
||||
class="h-5 w-5 flex-shrink-0 text-blue-600 dark:text-blue-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
<div>
|
||||
<p class="text-sm font-medium text-blue-800 dark:text-blue-300">
|
||||
{{ t('admin.accounts.gemini.modelPassthrough') }}
|
||||
</p>
|
||||
<p class="mt-1 text-xs text-blue-700 dark:text-blue-400">
|
||||
{{ t('admin.accounts.gemini.modelPassthroughDesc') }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Upstream fields (only for upstream type) -->
|
||||
@@ -641,9 +613,9 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Intercept Warmup Requests (Anthropic only) -->
|
||||
<!-- Intercept Warmup Requests (Anthropic/Antigravity) -->
|
||||
<div
|
||||
v-if="account?.platform === 'anthropic'"
|
||||
v-if="account?.platform === 'anthropic' || account?.platform === 'antigravity'"
|
||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||
>
|
||||
<div class="flex items-center justify-between">
|
||||
@@ -692,6 +664,7 @@
|
||||
class="input"
|
||||
data-tour="account-form-priority"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.priorityHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.billingRateMultiplier') }}</label>
|
||||
@@ -1139,7 +1112,7 @@
|
||||
<ConfirmDialog
|
||||
:show="showMixedChannelWarning"
|
||||
:title="t('admin.accounts.mixedChannelWarningTitle')"
|
||||
:message="mixedChannelWarningDetails ? t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails) : ''"
|
||||
:message="mixedChannelWarningMessageText"
|
||||
:confirm-text="t('common.confirm')"
|
||||
:cancel-text="t('common.cancel')"
|
||||
:danger="true"
|
||||
@@ -1154,7 +1127,7 @@ import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, Proxy, AdminGroup } from '@/types'
|
||||
import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse } from '@/types'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
@@ -1162,6 +1135,7 @@ import Icon from '@/components/icons/Icon.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
|
||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import {
|
||||
@@ -1233,10 +1207,13 @@ const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-mod
|
||||
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
|
||||
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
|
||||
|
||||
// Mixed channel warning dialog state
|
||||
const showMixedChannelWarning = ref(false)
|
||||
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(null)
|
||||
const pendingUpdatePayload = ref<Record<string, unknown> | null>(null)
|
||||
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
|
||||
null
|
||||
)
|
||||
const mixedChannelWarningRawMessage = ref('')
|
||||
const mixedChannelWarningAction = ref<(() => Promise<void>) | null>(null)
|
||||
const antigravityMixedChannelConfirmed = ref(false)
|
||||
|
||||
// Quota control state (Anthropic OAuth/SetupToken only)
|
||||
const windowCostEnabled = ref(false)
|
||||
@@ -1297,6 +1274,13 @@ const defaultBaseUrl = computed(() => {
|
||||
return 'https://api.anthropic.com'
|
||||
})
|
||||
|
||||
const mixedChannelWarningMessageText = computed(() => {
|
||||
if (mixedChannelWarningDetails.value) {
|
||||
return t('admin.accounts.mixedChannelWarning', mixedChannelWarningDetails.value)
|
||||
}
|
||||
return mixedChannelWarningRawMessage.value
|
||||
})
|
||||
|
||||
const form = reactive({
|
||||
name: '',
|
||||
notes: '',
|
||||
@@ -1326,6 +1310,11 @@ watch(
|
||||
() => props.account,
|
||||
(newAccount) => {
|
||||
if (newAccount) {
|
||||
antigravityMixedChannelConfirmed.value = false
|
||||
showMixedChannelWarning.value = false
|
||||
mixedChannelWarningDetails.value = null
|
||||
mixedChannelWarningRawMessage.value = ''
|
||||
mixedChannelWarningAction.value = null
|
||||
form.name = newAccount.name
|
||||
form.notes = newAccount.notes || ''
|
||||
form.proxy_id = newAccount.proxy_id
|
||||
@@ -1725,18 +1714,123 @@ function toPositiveNumber(value: unknown) {
|
||||
return Math.trunc(num)
|
||||
}
|
||||
|
||||
const needsMixedChannelCheck = () => props.account?.platform === 'antigravity' || props.account?.platform === 'anthropic'
|
||||
|
||||
const buildMixedChannelDetails = (resp?: CheckMixedChannelResponse) => {
|
||||
const details = resp?.details
|
||||
if (!details) {
|
||||
return null
|
||||
}
|
||||
return {
|
||||
groupName: details.group_name || 'Unknown',
|
||||
currentPlatform: details.current_platform || 'Unknown',
|
||||
otherPlatform: details.other_platform || 'Unknown'
|
||||
}
|
||||
}
|
||||
|
||||
const clearMixedChannelDialog = () => {
|
||||
showMixedChannelWarning.value = false
|
||||
mixedChannelWarningDetails.value = null
|
||||
mixedChannelWarningRawMessage.value = ''
|
||||
mixedChannelWarningAction.value = null
|
||||
}
|
||||
|
||||
const openMixedChannelDialog = (opts: {
|
||||
response?: CheckMixedChannelResponse
|
||||
message?: string
|
||||
onConfirm: () => Promise<void>
|
||||
}) => {
|
||||
mixedChannelWarningDetails.value = buildMixedChannelDetails(opts.response)
|
||||
mixedChannelWarningRawMessage.value =
|
||||
opts.message || opts.response?.message || t('admin.accounts.failedToUpdate')
|
||||
mixedChannelWarningAction.value = opts.onConfirm
|
||||
showMixedChannelWarning.value = true
|
||||
}
|
||||
|
||||
const withAntigravityConfirmFlag = (payload: Record<string, unknown>) => {
|
||||
if (needsMixedChannelCheck() && antigravityMixedChannelConfirmed.value) {
|
||||
return {
|
||||
...payload,
|
||||
confirm_mixed_channel_risk: true
|
||||
}
|
||||
}
|
||||
const cloned = { ...payload }
|
||||
delete cloned.confirm_mixed_channel_risk
|
||||
return cloned
|
||||
}
|
||||
|
||||
const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise<void>): Promise<boolean> => {
|
||||
if (!needsMixedChannelCheck()) {
|
||||
return true
|
||||
}
|
||||
if (antigravityMixedChannelConfirmed.value) {
|
||||
return true
|
||||
}
|
||||
if (!props.account) {
|
||||
return false
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await adminAPI.accounts.checkMixedChannelRisk({
|
||||
platform: props.account.platform,
|
||||
group_ids: form.group_ids,
|
||||
account_id: props.account.id
|
||||
})
|
||||
if (!result.has_risk) {
|
||||
return true
|
||||
}
|
||||
openMixedChannelDialog({
|
||||
response: result,
|
||||
onConfirm: async () => {
|
||||
antigravityMixedChannelConfirmed.value = true
|
||||
await onConfirm()
|
||||
}
|
||||
})
|
||||
return false
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const formatDateTimeLocal = formatDateTimeLocalInput
|
||||
const parseDateTimeLocal = parseDateTimeLocalInput
|
||||
|
||||
// Methods
|
||||
const handleClose = () => {
|
||||
antigravityMixedChannelConfirmed.value = false
|
||||
clearMixedChannelDialog()
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const submitUpdateAccount = async (accountID: number, updatePayload: Record<string, unknown>) => {
|
||||
submitting.value = true
|
||||
try {
|
||||
const updatedAccount = await adminAPI.accounts.update(accountID, withAntigravityConfirmFlag(updatePayload))
|
||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
||||
emit('updated', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning' && needsMixedChannelCheck()) {
|
||||
openMixedChannelDialog({
|
||||
message: error.response?.data?.message,
|
||||
onConfirm: async () => {
|
||||
antigravityMixedChannelConfirmed.value = true
|
||||
await submitUpdateAccount(accountID, updatePayload)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!props.account) return
|
||||
const accountID = props.account.id
|
||||
|
||||
submitting.value = true
|
||||
const updatePayload: Record<string, unknown> = { ...form }
|
||||
try {
|
||||
// 后端期望 proxy_id: 0 表示清除代理,而不是 null
|
||||
@@ -1768,7 +1862,6 @@ const handleSubmit = async () => {
|
||||
newCredentials.api_key = currentCredentials.api_key
|
||||
} else {
|
||||
appStore.showError(t('admin.accounts.apiKeyIsRequired'))
|
||||
submitting.value = false
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1789,11 +1882,8 @@ const handleSubmit = async () => {
|
||||
}
|
||||
|
||||
// Add intercept warmup requests setting
|
||||
if (interceptWarmupRequests.value) {
|
||||
newCredentials.intercept_warmup_requests = true
|
||||
}
|
||||
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||
if (!applyTempUnschedConfig(newCredentials)) {
|
||||
submitting.value = false
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1808,8 +1898,10 @@ const handleSubmit = async () => {
|
||||
newCredentials.api_key = editApiKey.value.trim()
|
||||
}
|
||||
|
||||
// Add intercept warmup requests setting
|
||||
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||
|
||||
if (!applyTempUnschedConfig(newCredentials)) {
|
||||
submitting.value = false
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1819,13 +1911,8 @@ const handleSubmit = async () => {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
if (interceptWarmupRequests.value) {
|
||||
newCredentials.intercept_warmup_requests = true
|
||||
} else {
|
||||
delete newCredentials.intercept_warmup_requests
|
||||
}
|
||||
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||
if (!applyTempUnschedConfig(newCredentials)) {
|
||||
submitting.value = false
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1955,52 +2042,36 @@ const handleSubmit = async () => {
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload)
|
||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
||||
emit('updated', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
// Handle 409 mixed_channel_warning - show confirmation dialog
|
||||
if (error.response?.status === 409 && error.response?.data?.error === 'mixed_channel_warning') {
|
||||
const details = error.response.data.details || {}
|
||||
mixedChannelWarningDetails.value = {
|
||||
groupName: details.group_name || 'Unknown',
|
||||
currentPlatform: details.current_platform || 'Unknown',
|
||||
otherPlatform: details.other_platform || 'Unknown'
|
||||
}
|
||||
pendingUpdatePayload.value = updatePayload
|
||||
showMixedChannelWarning.value = true
|
||||
} else {
|
||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||
const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => {
|
||||
await submitUpdateAccount(accountID, updatePayload)
|
||||
})
|
||||
if (!canContinue) {
|
||||
return
|
||||
}
|
||||
} finally {
|
||||
submitting.value = false
|
||||
|
||||
await submitUpdateAccount(accountID, updatePayload)
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||
}
|
||||
}
|
||||
|
||||
// Handle mixed channel warning confirmation
|
||||
const handleMixedChannelConfirm = async () => {
|
||||
showMixedChannelWarning.value = false
|
||||
if (pendingUpdatePayload.value && props.account) {
|
||||
pendingUpdatePayload.value.confirm_mixed_channel_risk = true
|
||||
submitting.value = true
|
||||
try {
|
||||
const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
|
||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
||||
emit('updated', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||
} finally {
|
||||
submitting.value = false
|
||||
pendingUpdatePayload.value = null
|
||||
}
|
||||
const action = mixedChannelWarningAction.value
|
||||
if (!action) {
|
||||
clearMixedChannelDialog()
|
||||
return
|
||||
}
|
||||
clearMixedChannelDialog()
|
||||
submitting.value = true
|
||||
try {
|
||||
await action()
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const handleMixedChannelCancel = () => {
|
||||
showMixedChannelWarning.value = false
|
||||
pendingUpdatePayload.value = null
|
||||
mixedChannelWarningDetails.value = null
|
||||
clearMixedChannelDialog()
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -59,6 +59,17 @@
|
||||
t(getOAuthKey('sessionTokenAuth'))
|
||||
}}</span>
|
||||
</label>
|
||||
<label v-if="showAccessTokenOption" class="flex cursor-pointer items-center gap-2">
|
||||
<input
|
||||
v-model="inputMethod"
|
||||
type="radio"
|
||||
value="access_token"
|
||||
class="text-blue-600 focus:ring-blue-500"
|
||||
/>
|
||||
<span class="text-sm text-blue-900 dark:text-blue-200">{{
|
||||
t('admin.accounts.oauth.openai.accessTokenAuth', '手动输入 AT')
|
||||
}}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -227,6 +238,63 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Access Token Input (Sora) -->
|
||||
<div v-if="inputMethod === 'access_token'" class="space-y-4">
|
||||
<div
|
||||
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
|
||||
>
|
||||
<p class="mb-3 text-sm text-blue-700 dark:text-blue-300">
|
||||
{{ t('admin.accounts.oauth.openai.accessTokenDesc', '直接粘贴 Access Token 创建账号,无需 OAuth 授权流程。支持批量导入(每行一个)。') }}
|
||||
</p>
|
||||
|
||||
<div class="mb-4">
|
||||
<label
|
||||
class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
<Icon name="key" size="sm" class="text-blue-500" />
|
||||
Access Token
|
||||
<span
|
||||
v-if="parsedAccessTokenCount > 1"
|
||||
class="rounded-full bg-blue-500 px-2 py-0.5 text-xs text-white"
|
||||
>
|
||||
{{ t('admin.accounts.oauth.keysCount', { count: parsedAccessTokenCount }) }}
|
||||
</span>
|
||||
</label>
|
||||
<textarea
|
||||
v-model="accessTokenInput"
|
||||
rows="3"
|
||||
class="input w-full resize-y font-mono text-sm"
|
||||
:placeholder="t('admin.accounts.oauth.openai.accessTokenPlaceholder', '粘贴 Access Token,每行一个')"
|
||||
></textarea>
|
||||
<p
|
||||
v-if="parsedAccessTokenCount > 1"
|
||||
class="mt-1 text-xs text-blue-600 dark:text-blue-400"
|
||||
>
|
||||
{{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedAccessTokenCount }) }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="error"
|
||||
class="mb-4 rounded-lg border border-red-200 bg-red-50 p-3 dark:border-red-700 dark:bg-red-900/30"
|
||||
>
|
||||
<p class="whitespace-pre-line text-sm text-red-600 dark:text-red-400">
|
||||
{{ error }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="btn btn-primary w-full"
|
||||
:disabled="loading || !accessTokenInput.trim()"
|
||||
@click="handleImportAccessToken"
|
||||
>
|
||||
<Icon name="sparkles" size="sm" class="mr-2" />
|
||||
{{ t('admin.accounts.oauth.openai.importAccessToken', '导入 Access Token') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Cookie Auto-Auth Form -->
|
||||
<div v-if="inputMethod === 'cookie'" class="space-y-4">
|
||||
<div
|
||||
@@ -618,6 +686,7 @@ interface Props {
|
||||
showCookieOption?: boolean // Whether to show cookie auto-auth option
|
||||
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
|
||||
showSessionTokenOption?: boolean // Whether to show session token input option (Sora only)
|
||||
showAccessTokenOption?: boolean // Whether to show access token input option (Sora only)
|
||||
platform?: AccountPlatform // Platform type for different UI/text
|
||||
showProjectId?: boolean // New prop to control project ID visibility
|
||||
}
|
||||
@@ -634,6 +703,7 @@ const props = withDefaults(defineProps<Props>(), {
|
||||
showCookieOption: true,
|
||||
showRefreshTokenOption: false,
|
||||
showSessionTokenOption: false,
|
||||
showAccessTokenOption: false,
|
||||
platform: 'anthropic',
|
||||
showProjectId: true
|
||||
})
|
||||
@@ -644,6 +714,7 @@ const emit = defineEmits<{
|
||||
'cookie-auth': [sessionKey: string]
|
||||
'validate-refresh-token': [refreshToken: string]
|
||||
'validate-session-token': [sessionToken: string]
|
||||
'import-access-token': [accessToken: string]
|
||||
'update:inputMethod': [method: AuthInputMethod]
|
||||
}>()
|
||||
|
||||
@@ -683,12 +754,13 @@ const authCodeInput = ref('')
|
||||
const sessionKeyInput = ref('')
|
||||
const refreshTokenInput = ref('')
|
||||
const sessionTokenInput = ref('')
|
||||
const accessTokenInput = ref('')
|
||||
const showHelpDialog = ref(false)
|
||||
const oauthState = ref('')
|
||||
const projectId = ref('')
|
||||
|
||||
// Computed: show method selection when either cookie or refresh token option is enabled
|
||||
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption)
|
||||
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption)
|
||||
|
||||
// Clipboard
|
||||
const { copied, copyToClipboard } = useClipboard()
|
||||
@@ -716,6 +788,13 @@ const parsedSessionTokenCount = computed(() => {
|
||||
.filter((st) => st).length
|
||||
})
|
||||
|
||||
const parsedAccessTokenCount = computed(() => {
|
||||
return accessTokenInput.value
|
||||
.split('\n')
|
||||
.map((at) => at.trim())
|
||||
.filter((at) => at).length
|
||||
})
|
||||
|
||||
// Watchers
|
||||
watch(inputMethod, (newVal) => {
|
||||
emit('update:inputMethod', newVal)
|
||||
@@ -789,6 +868,12 @@ const handleValidateSessionToken = () => {
|
||||
}
|
||||
}
|
||||
|
||||
const handleImportAccessToken = () => {
|
||||
if (accessTokenInput.value.trim()) {
|
||||
emit('import-access-token', accessTokenInput.value.trim())
|
||||
}
|
||||
}
|
||||
|
||||
// Expose methods and state
|
||||
defineExpose({
|
||||
authCode: authCodeInput,
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import { describe, it, expect } from 'vitest'
|
||||
import { applyInterceptWarmup } from '../credentialsBuilder'
|
||||
|
||||
describe('applyInterceptWarmup', () => {
|
||||
it('create + enabled=true: should set intercept_warmup_requests to true', () => {
|
||||
const creds: Record<string, unknown> = { access_token: 'tok' }
|
||||
applyInterceptWarmup(creds, true, 'create')
|
||||
expect(creds.intercept_warmup_requests).toBe(true)
|
||||
})
|
||||
|
||||
it('create + enabled=false: should not add the field', () => {
|
||||
const creds: Record<string, unknown> = { access_token: 'tok' }
|
||||
applyInterceptWarmup(creds, false, 'create')
|
||||
expect('intercept_warmup_requests' in creds).toBe(false)
|
||||
})
|
||||
|
||||
it('edit + enabled=true: should set intercept_warmup_requests to true', () => {
|
||||
const creds: Record<string, unknown> = { api_key: 'sk' }
|
||||
applyInterceptWarmup(creds, true, 'edit')
|
||||
expect(creds.intercept_warmup_requests).toBe(true)
|
||||
})
|
||||
|
||||
it('edit + enabled=false + field exists: should delete the field', () => {
|
||||
const creds: Record<string, unknown> = { api_key: 'sk', intercept_warmup_requests: true }
|
||||
applyInterceptWarmup(creds, false, 'edit')
|
||||
expect('intercept_warmup_requests' in creds).toBe(false)
|
||||
})
|
||||
|
||||
it('edit + enabled=false + field absent: should not throw', () => {
|
||||
const creds: Record<string, unknown> = { api_key: 'sk' }
|
||||
applyInterceptWarmup(creds, false, 'edit')
|
||||
expect('intercept_warmup_requests' in creds).toBe(false)
|
||||
})
|
||||
|
||||
it('should not affect other fields', () => {
|
||||
const creds: Record<string, unknown> = {
|
||||
api_key: 'sk',
|
||||
base_url: 'url',
|
||||
intercept_warmup_requests: true
|
||||
}
|
||||
applyInterceptWarmup(creds, false, 'edit')
|
||||
expect(creds.api_key).toBe('sk')
|
||||
expect(creds.base_url).toBe('url')
|
||||
expect('intercept_warmup_requests' in creds).toBe(false)
|
||||
})
|
||||
})
|
||||
11
frontend/src/components/account/credentialsBuilder.ts
Normal file
11
frontend/src/components/account/credentialsBuilder.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
export function applyInterceptWarmup(
|
||||
credentials: Record<string, unknown>,
|
||||
enabled: boolean,
|
||||
mode: 'create' | 'edit'
|
||||
): void {
|
||||
if (enabled) {
|
||||
credentials.intercept_warmup_requests = true
|
||||
} else if (mode === 'edit') {
|
||||
delete credentials.intercept_warmup_requests
|
||||
}
|
||||
}
|
||||
@@ -160,6 +160,7 @@
|
||||
<button type="button" @click="$emit('reset')" class="btn btn-secondary">
|
||||
{{ t('common.reset') }}
|
||||
</button>
|
||||
<slot name="after-reset" />
|
||||
<button type="button" @click="$emit('cleanup')" class="btn btn-danger">
|
||||
{{ t('admin.usage.cleanup.button') }}
|
||||
</button>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<template>
|
||||
<div class="card overflow-hidden">
|
||||
<div class="overflow-auto">
|
||||
<DataTable :columns="cols" :data="data" :loading="loading">
|
||||
<DataTable :columns="columns" :data="data" :loading="loading">
|
||||
<template #cell-user="{ row }">
|
||||
<div class="text-sm">
|
||||
<span class="font-medium text-gray-900 dark:text-white">{{ row.user?.email || '-' }}</span>
|
||||
@@ -123,7 +123,7 @@
|
||||
</template>
|
||||
|
||||
<template #cell-user_agent="{ row }">
|
||||
<span v-if="row.user_agent" class="text-sm text-gray-600 dark:text-gray-400 block max-w-[320px] whitespace-normal break-all" :title="row.user_agent">{{ formatUserAgent(row.user_agent) }}</span>
|
||||
<span v-if="row.user_agent" class="text-sm text-gray-600 dark:text-gray-400 block max-w-[320px] truncate" :title="row.user_agent">{{ formatUserAgent(row.user_agent) }}</span>
|
||||
<span v-else class="text-sm text-gray-400 dark:text-gray-500">-</span>
|
||||
</template>
|
||||
|
||||
@@ -268,7 +268,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { formatDateTime, formatReasoningEffort } from '@/utils/format'
|
||||
import DataTable from '@/components/common/DataTable.vue'
|
||||
@@ -276,7 +276,7 @@ import EmptyState from '@/components/common/EmptyState.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import type { AdminUsageLog } from '@/types'
|
||||
|
||||
defineProps(['data', 'loading'])
|
||||
defineProps(['data', 'loading', 'columns'])
|
||||
const { t } = useI18n()
|
||||
|
||||
// Tooltip state - cost
|
||||
@@ -289,23 +289,6 @@ const tokenTooltipVisible = ref(false)
|
||||
const tokenTooltipPosition = ref({ x: 0, y: 0 })
|
||||
const tokenTooltipData = ref<AdminUsageLog | null>(null)
|
||||
|
||||
const cols = computed(() => [
|
||||
{ key: 'user', label: t('admin.usage.user'), sortable: false },
|
||||
{ key: 'api_key', label: t('usage.apiKeyFilter'), sortable: false },
|
||||
{ key: 'account', label: t('admin.usage.account'), sortable: false },
|
||||
{ key: 'model', label: t('usage.model'), sortable: true },
|
||||
{ key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false },
|
||||
{ key: 'group', label: t('admin.usage.group'), sortable: false },
|
||||
{ key: 'stream', label: t('usage.type'), sortable: false },
|
||||
{ key: 'tokens', label: t('usage.tokens'), sortable: false },
|
||||
{ key: 'cost', label: t('usage.cost'), sortable: false },
|
||||
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
|
||||
{ key: 'duration', label: t('usage.duration'), sortable: false },
|
||||
{ key: 'created_at', label: t('usage.time'), sortable: true },
|
||||
{ key: 'user_agent', label: t('usage.userAgent'), sortable: false },
|
||||
{ key: 'ip_address', label: t('admin.usage.ipAddress'), sortable: false }
|
||||
])
|
||||
|
||||
const formatCacheTokens = (tokens: number): string => {
|
||||
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
|
||||
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
|
||||
|
||||
@@ -534,8 +534,104 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
|
||||
}
|
||||
}
|
||||
const openaiModels = {
|
||||
'gpt-5-codex': {
|
||||
name: 'GPT-5 Codex',
|
||||
limit: {
|
||||
context: 400000,
|
||||
output: 128000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
variants: {
|
||||
low: {},
|
||||
medium: {},
|
||||
high: {}
|
||||
}
|
||||
},
|
||||
'gpt-5.1-codex': {
|
||||
name: 'GPT-5.1 Codex',
|
||||
limit: {
|
||||
context: 400000,
|
||||
output: 128000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
variants: {
|
||||
low: {},
|
||||
medium: {},
|
||||
high: {}
|
||||
}
|
||||
},
|
||||
'gpt-5.1-codex-max': {
|
||||
name: 'GPT-5.1 Codex Max',
|
||||
limit: {
|
||||
context: 400000,
|
||||
output: 128000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
variants: {
|
||||
low: {},
|
||||
medium: {},
|
||||
high: {}
|
||||
}
|
||||
},
|
||||
'gpt-5.1-codex-mini': {
|
||||
name: 'GPT-5.1 Codex Mini',
|
||||
limit: {
|
||||
context: 400000,
|
||||
output: 128000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
variants: {
|
||||
low: {},
|
||||
medium: {},
|
||||
high: {}
|
||||
}
|
||||
},
|
||||
'gpt-5.2': {
|
||||
name: 'GPT-5.2',
|
||||
limit: {
|
||||
context: 400000,
|
||||
output: 128000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
variants: {
|
||||
low: {},
|
||||
medium: {},
|
||||
high: {},
|
||||
xhigh: {}
|
||||
}
|
||||
},
|
||||
'gpt-5.3-codex-spark': {
|
||||
name: 'GPT-5.3 Codex Spark',
|
||||
limit: {
|
||||
context: 128000,
|
||||
output: 32000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
variants: {
|
||||
low: {},
|
||||
medium: {},
|
||||
high: {},
|
||||
xhigh: {}
|
||||
}
|
||||
},
|
||||
'gpt-5.3-codex': {
|
||||
name: 'GPT-5.3 Codex',
|
||||
limit: {
|
||||
context: 400000,
|
||||
output: 128000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
@@ -548,6 +644,10 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
|
||||
},
|
||||
'gpt-5.2-codex': {
|
||||
name: 'GPT-5.2 Codex',
|
||||
limit: {
|
||||
context: 400000,
|
||||
output: 128000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
@@ -557,30 +657,266 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
|
||||
high: {},
|
||||
xhigh: {}
|
||||
}
|
||||
},
|
||||
'codex-mini-latest': {
|
||||
name: 'Codex Mini',
|
||||
limit: {
|
||||
context: 200000,
|
||||
output: 100000
|
||||
},
|
||||
options: {
|
||||
store: false
|
||||
},
|
||||
variants: {
|
||||
low: {},
|
||||
medium: {},
|
||||
high: {}
|
||||
}
|
||||
}
|
||||
}
|
||||
const geminiModels = {
|
||||
'gemini-2.0-flash': { name: 'Gemini 2.0 Flash' },
|
||||
'gemini-2.5-flash': { name: 'Gemini 2.5 Flash' },
|
||||
'gemini-2.5-pro': { name: 'Gemini 2.5 Pro' },
|
||||
'gemini-3-flash-preview': { name: 'Gemini 3 Flash Preview' },
|
||||
'gemini-3-pro-preview': { name: 'Gemini 3 Pro Preview' }
|
||||
'gemini-2.0-flash': {
|
||||
name: 'Gemini 2.0 Flash',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
}
|
||||
},
|
||||
'gemini-2.5-flash': {
|
||||
name: 'Gemini 2.5 Flash',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
}
|
||||
},
|
||||
'gemini-2.5-pro': {
|
||||
name: 'Gemini 2.5 Pro',
|
||||
limit: {
|
||||
context: 2097152,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
},
|
||||
'gemini-3-flash-preview': {
|
||||
name: 'Gemini 3 Flash Preview',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
}
|
||||
},
|
||||
'gemini-3-pro-preview': {
|
||||
name: 'Gemini 3 Pro Preview',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
},
|
||||
'gemini-3.1-pro-preview': {
|
||||
name: 'Gemini 3.1 Pro Preview',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const antigravityGeminiModels = {
|
||||
'gemini-2.5-flash': { name: 'Gemini 2.5 Flash' },
|
||||
'gemini-2.5-flash-lite': { name: 'Gemini 2.5 Flash Lite' },
|
||||
'gemini-2.5-flash-thinking': { name: 'Gemini 2.5 Flash Thinking' },
|
||||
'gemini-3-flash': { name: 'Gemini 3 Flash' },
|
||||
'gemini-3-pro-low': { name: 'Gemini 3 Pro Low' },
|
||||
'gemini-3-pro-high': { name: 'Gemini 3 Pro High' },
|
||||
'gemini-3-pro-preview': { name: 'Gemini 3 Pro Preview' },
|
||||
'gemini-3-pro-image': { name: 'Gemini 3 Pro Image' }
|
||||
'gemini-2.5-flash': {
|
||||
name: 'Gemini 2.5 Flash',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'disable'
|
||||
}
|
||||
}
|
||||
},
|
||||
'gemini-2.5-flash-lite': {
|
||||
name: 'Gemini 2.5 Flash Lite',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
},
|
||||
'gemini-2.5-flash-thinking': {
|
||||
name: 'Gemini 2.5 Flash (Thinking)',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
},
|
||||
'gemini-3-flash': {
|
||||
name: 'Gemini 3 Flash',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
},
|
||||
'gemini-3.1-pro-low': {
|
||||
name: 'Gemini 3.1 Pro Low',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
},
|
||||
'gemini-3.1-pro-high': {
|
||||
name: 'Gemini 3.1 Pro High',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
},
|
||||
'gemini-3.1-flash-image': {
|
||||
name: 'Gemini 3.1 Flash Image',
|
||||
limit: {
|
||||
context: 1048576,
|
||||
output: 65536
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image'],
|
||||
output: ['image']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const claudeModels = {
|
||||
'claude-opus-4-5-thinking': { name: 'Claude Opus 4.5 Thinking' },
|
||||
'claude-sonnet-4-5-thinking': { name: 'Claude Sonnet 4.5 Thinking' },
|
||||
'claude-sonnet-4-5': { name: 'Claude Sonnet 4.5' }
|
||||
'claude-opus-4-6-thinking': {
|
||||
name: 'Claude 4.6 Opus (Thinking)',
|
||||
limit: {
|
||||
context: 200000,
|
||||
output: 128000
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
},
|
||||
'claude-sonnet-4-6': {
|
||||
name: 'Claude 4.6 Sonnet',
|
||||
limit: {
|
||||
context: 200000,
|
||||
output: 64000
|
||||
},
|
||||
modalities: {
|
||||
input: ['text', 'image', 'pdf'],
|
||||
output: ['text']
|
||||
},
|
||||
options: {
|
||||
thinking: {
|
||||
budgetTokens: 24576,
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (platform === 'gemini') {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
|
||||
export type AddMethod = 'oauth' | 'setup-token'
|
||||
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token'
|
||||
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' | 'access_token'
|
||||
|
||||
export interface OAuthState {
|
||||
authUrl: string
|
||||
|
||||
@@ -24,6 +24,8 @@ const openaiModels = [
|
||||
// GPT-5.2 系列
|
||||
'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest',
|
||||
'gpt-5.2-codex', 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11',
|
||||
// GPT-5.3 系列
|
||||
'gpt-5.3-codex', 'gpt-5.3-codex-spark',
|
||||
'chatgpt-4o-latest',
|
||||
'gpt-4o-audio-preview', 'gpt-4o-realtime-preview'
|
||||
]
|
||||
@@ -75,7 +77,9 @@ const soraModels = [
|
||||
const antigravityModels = [
|
||||
// Claude 4.5+ 系列
|
||||
'claude-opus-4-6',
|
||||
'claude-opus-4-6-thinking',
|
||||
'claude-opus-4-5-thinking',
|
||||
'claude-sonnet-4-6',
|
||||
'claude-sonnet-4-5',
|
||||
'claude-sonnet-4-5-thinking',
|
||||
// Gemini 2.5 系列
|
||||
@@ -87,7 +91,10 @@ const antigravityModels = [
|
||||
'gemini-3-flash',
|
||||
'gemini-3-pro-high',
|
||||
'gemini-3-pro-low',
|
||||
'gemini-3-pro-image',
|
||||
// Gemini 3.1 系列
|
||||
'gemini-3.1-pro-high',
|
||||
'gemini-3.1-pro-low',
|
||||
'gemini-3.1-flash-image',
|
||||
// 其他
|
||||
'gpt-oss-120b-medium',
|
||||
'tab_flash_lite_preview'
|
||||
@@ -287,15 +294,25 @@ const antigravityPresetMappings = [
|
||||
{ label: 'Sonnet→Sonnet', from: 'claude-sonnet-*', to: 'claude-sonnet-4-5', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
|
||||
{ label: 'Opus→Opus', from: 'claude-opus-*', to: 'claude-opus-4-6-thinking', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||
{ label: 'Haiku→Sonnet', from: 'claude-haiku-*', to: 'claude-sonnet-4-5', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
||||
{ label: 'Sonnet4→4.6', from: 'claude-sonnet-4-20250514', to: 'claude-sonnet-4-6', color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' },
|
||||
{ label: 'Sonnet4.5→4.6', from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||
{ label: 'Sonnet3.5→4.6', from: 'claude-3-5-sonnet-20241022', to: 'claude-sonnet-4-6', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' },
|
||||
{ label: 'Opus4.5→4.6', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-6-thinking', color: 'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400' },
|
||||
// Gemini 3→3.1 映射
|
||||
{ label: '3-Pro-Preview→3.1-Pro-High', from: 'gemini-3-pro-preview', to: 'gemini-3.1-pro-high', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' },
|
||||
{ label: '3-Pro-High→3.1-Pro-High', from: 'gemini-3-pro-high', to: 'gemini-3.1-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
||||
{ label: '3-Pro-Low→3.1-Pro-Low', from: 'gemini-3-pro-low', to: 'gemini-3.1-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
||||
{ label: '3.1-Pro-High透传', from: 'gemini-3.1-pro-high', to: 'gemini-3.1-pro-high', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
||||
{ label: '3.1-Pro-Low透传', from: 'gemini-3.1-pro-low', to: 'gemini-3.1-pro-low', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
||||
// Gemini 通配符映射
|
||||
{ label: 'Gemini 3→Flash', from: 'gemini-3*', to: 'gemini-3-flash', color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400' },
|
||||
{ label: 'Gemini 2.5→Flash', from: 'gemini-2.5*', to: 'gemini-2.5-flash', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
|
||||
{ label: '3-Flash透传', from: 'gemini-3-flash', to: 'gemini-3-flash', color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400' },
|
||||
{ label: '2.5-Flash-Lite透传', from: 'gemini-2.5-flash-lite', to: 'gemini-2.5-flash-lite', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
||||
// 精确映射
|
||||
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
|
||||
]
|
||||
|
||||
|
||||
@@ -68,6 +68,14 @@ export async function setLocale(locale: string): Promise<void> {
|
||||
i18n.global.locale.value = locale
|
||||
localStorage.setItem(LOCALE_KEY, locale)
|
||||
document.documentElement.setAttribute('lang', locale)
|
||||
|
||||
// 同步更新浏览器页签标题,使其跟随语言切换
|
||||
const { resolveDocumentTitle } = await import('@/router/title')
|
||||
const { default: router } = await import('@/router')
|
||||
const { useAppStore } = await import('@/stores/app')
|
||||
const route = router.currentRoute.value
|
||||
const appStore = useAppStore()
|
||||
document.title = resolveDocumentTitle(route.meta.title, appStore.siteName, route.meta.titleKey as string)
|
||||
}
|
||||
|
||||
export function getLocale(): LocaleCode {
|
||||
|
||||
@@ -1133,7 +1133,7 @@ export default {
|
||||
},
|
||||
imagePricing: {
|
||||
title: 'Image Generation Pricing',
|
||||
description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.'
|
||||
description: 'Configure pricing for image generation models. Leave empty to use default prices.'
|
||||
},
|
||||
soraPricing: {
|
||||
title: 'Sora Per-Request Pricing',
|
||||
@@ -1505,7 +1505,8 @@ export default {
|
||||
partialSuccess: 'Partially updated: {success} succeeded, {failed} failed',
|
||||
failed: 'Bulk update failed',
|
||||
noSelection: 'Please select accounts to edit',
|
||||
noFieldsSelected: 'Select at least one field to update'
|
||||
noFieldsSelected: 'Select at least one field to update',
|
||||
mixedPlatformWarning: 'Selected accounts span multiple platforms ({platforms}). Model mapping presets shown are combined — ensure mappings are appropriate for each platform.'
|
||||
},
|
||||
bulkDeleteTitle: 'Bulk Delete Accounts',
|
||||
bulkDeleteConfirm: 'Delete the selected {count} account(s)? This action cannot be undone.',
|
||||
@@ -2046,8 +2047,8 @@ export default {
|
||||
geminiFlashDaily: 'Flash',
|
||||
gemini3Pro: 'G3P',
|
||||
gemini3Flash: 'G3F',
|
||||
gemini3Image: 'G3I',
|
||||
claude45: 'C4.5'
|
||||
gemini3Image: 'GImage',
|
||||
claude: 'Claude'
|
||||
},
|
||||
tier: {
|
||||
free: 'Free',
|
||||
|
||||
@@ -1220,7 +1220,7 @@ export default {
|
||||
},
|
||||
imagePricing: {
|
||||
title: '图片生成计费',
|
||||
description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格'
|
||||
description: '配置图片生成模型的图片生成价格,留空则使用默认价格'
|
||||
},
|
||||
soraPricing: {
|
||||
title: 'Sora 按次计费',
|
||||
@@ -1582,8 +1582,8 @@ export default {
|
||||
geminiFlashDaily: 'Flash',
|
||||
gemini3Pro: 'G3P',
|
||||
gemini3Flash: 'G3F',
|
||||
gemini3Image: 'G3I',
|
||||
claude45: 'C4.5'
|
||||
gemini3Image: 'GImage',
|
||||
claude: 'Claude'
|
||||
},
|
||||
tier: {
|
||||
free: 'Free',
|
||||
@@ -1652,7 +1652,8 @@ export default {
|
||||
partialSuccess: '部分更新成功:成功 {success} 个,失败 {failed} 个',
|
||||
failed: '批量更新失败',
|
||||
noSelection: '请选择要编辑的账号',
|
||||
noFieldsSelected: '请至少选择一个要更新的字段'
|
||||
noFieldsSelected: '请至少选择一个要更新的字段',
|
||||
mixedPlatformWarning: '所选账号跨越多个平台({platforms})。显示的模型映射预设为合并结果——请确保映射对每个平台都适用。'
|
||||
},
|
||||
bulkDeleteTitle: '批量删除账号',
|
||||
bulkDeleteConfirm: '确定要删除选中的 {count} 个账号吗?此操作无法撤销。',
|
||||
|
||||
@@ -41,7 +41,8 @@ const routes: RouteRecordRaw[] = [
|
||||
component: () => import('@/views/auth/LoginView.vue'),
|
||||
meta: {
|
||||
requiresAuth: false,
|
||||
title: 'Login'
|
||||
title: 'Login',
|
||||
titleKey: 'common.login'
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -50,7 +51,8 @@ const routes: RouteRecordRaw[] = [
|
||||
component: () => import('@/views/auth/RegisterView.vue'),
|
||||
meta: {
|
||||
requiresAuth: false,
|
||||
title: 'Register'
|
||||
title: 'Register',
|
||||
titleKey: 'auth.createAccount'
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -86,7 +88,8 @@ const routes: RouteRecordRaw[] = [
|
||||
component: () => import('@/views/auth/ForgotPasswordView.vue'),
|
||||
meta: {
|
||||
requiresAuth: false,
|
||||
title: 'Forgot Password'
|
||||
title: 'Forgot Password',
|
||||
titleKey: 'auth.forgotPasswordTitle'
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -390,7 +393,7 @@ router.beforeEach((to, _from, next) => {
|
||||
|
||||
// Set page title
|
||||
const appStore = useAppStore()
|
||||
document.title = resolveDocumentTitle(to.meta.title, appStore.siteName)
|
||||
document.title = resolveDocumentTitle(to.meta.title, appStore.siteName, to.meta.titleKey as string)
|
||||
|
||||
// Check if route requires authentication
|
||||
const requiresAuth = to.meta.requiresAuth !== false // Default to true
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import { i18n } from '@/i18n'
|
||||
|
||||
/**
|
||||
* 统一生成页面标题,避免多处写入 document.title 产生覆盖冲突。
|
||||
* 优先使用 titleKey 通过 i18n 翻译,fallback 到静态 routeTitle。
|
||||
*/
|
||||
export function resolveDocumentTitle(routeTitle: unknown, siteName?: string): string {
|
||||
export function resolveDocumentTitle(routeTitle: unknown, siteName?: string, titleKey?: string): string {
|
||||
const normalizedSiteName = typeof siteName === 'string' && siteName.trim() ? siteName.trim() : 'Sub2API'
|
||||
|
||||
if (typeof titleKey === 'string' && titleKey.trim()) {
|
||||
const translated = i18n.global.t(titleKey)
|
||||
if (translated && translated !== titleKey) {
|
||||
return `${translated} - ${normalizedSiteName}`
|
||||
}
|
||||
}
|
||||
|
||||
if (typeof routeTitle === 'string' && routeTitle.trim()) {
|
||||
return `${routeTitle.trim()} - ${normalizedSiteName}`
|
||||
}
|
||||
|
||||
@@ -581,6 +581,7 @@ export interface GeminiCredentials {
|
||||
token_type?: string
|
||||
scope?: string
|
||||
expires_at?: string
|
||||
model_mapping?: Record<string, string>
|
||||
}
|
||||
|
||||
export interface TempUnschedulableRule {
|
||||
@@ -766,6 +767,26 @@ export interface UpdateAccountRequest {
|
||||
confirm_mixed_channel_risk?: boolean
|
||||
}
|
||||
|
||||
export interface CheckMixedChannelRequest {
|
||||
platform: AccountPlatform
|
||||
group_ids: number[]
|
||||
account_id?: number
|
||||
}
|
||||
|
||||
export interface MixedChannelWarningDetails {
|
||||
group_id: number
|
||||
group_name: string
|
||||
current_platform: string
|
||||
other_platform: string
|
||||
}
|
||||
|
||||
export interface CheckMixedChannelResponse {
|
||||
has_risk: boolean
|
||||
error?: string
|
||||
message?: string
|
||||
details?: MixedChannelWarningDetails
|
||||
}
|
||||
|
||||
export interface CreateProxyRequest {
|
||||
name: string
|
||||
protocol: ProxyProtocol
|
||||
|
||||
@@ -259,7 +259,7 @@
|
||||
<AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @reauth="handleReAuth" @refresh-token="handleRefresh" @reset-status="handleResetStatus" @clear-rate-limit="handleClearRateLimit" />
|
||||
<SyncFromCrsModal :show="showSync" @close="showSync = false" @synced="reload" />
|
||||
<ImportDataModal :show="showImportData" @close="showImportData = false" @imported="handleDataImported" />
|
||||
<BulkEditAccountModal :show="showBulkEdit" :account-ids="selIds" :proxies="proxies" :groups="groups" @close="showBulkEdit = false" @updated="handleBulkUpdated" />
|
||||
<BulkEditAccountModal :show="showBulkEdit" :account-ids="selIds" :selected-platforms="selPlatforms" :proxies="proxies" :groups="groups" @close="showBulkEdit = false" @updated="handleBulkUpdated" />
|
||||
<TempUnschedStatusModal :show="showTempUnsched" :account="tempUnschedAcc" @close="showTempUnsched = false" @reset="handleTempUnschedReset" />
|
||||
<ConfirmDialog :show="showDeleteDialog" :title="t('admin.accounts.deleteAccount')" :message="t('admin.accounts.deleteConfirm', { name: deletingAcc?.name })" :confirm-text="t('common.delete')" :cancel-text="t('common.cancel')" :danger="true" @confirm="confirmDelete" @cancel="showDeleteDialog = false" />
|
||||
<ConfirmDialog :show="showExportDataDialog" :title="t('admin.accounts.dataExport')" :message="t('admin.accounts.dataExportConfirmMessage')" :confirm-text="t('admin.accounts.dataExportConfirm')" :cancel-text="t('common.cancel')" @confirm="handleExportData" @cancel="showExportDataDialog = false">
|
||||
@@ -303,7 +303,7 @@ import PlatformTypeBadge from '@/components/common/PlatformTypeBadge.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import ErrorPassthroughRulesModal from '@/components/admin/ErrorPassthroughRulesModal.vue'
|
||||
import { formatDateTime, formatRelativeTime } from '@/utils/format'
|
||||
import type { Account, Proxy, AdminGroup } from '@/types'
|
||||
import type { Account, AccountPlatform, Proxy, AdminGroup } from '@/types'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
@@ -312,6 +312,14 @@ const authStore = useAuthStore()
|
||||
const proxies = ref<Proxy[]>([])
|
||||
const groups = ref<AdminGroup[]>([])
|
||||
const selIds = ref<number[]>([])
|
||||
const selPlatforms = computed<AccountPlatform[]>(() => {
|
||||
const platforms = new Set(
|
||||
accounts.value
|
||||
.filter(a => selIds.value.includes(a.id))
|
||||
.map(a => a.platform)
|
||||
)
|
||||
return [...platforms]
|
||||
})
|
||||
const showCreate = ref(false)
|
||||
const showEdit = ref(false)
|
||||
const showSync = ref(false)
|
||||
|
||||
@@ -459,7 +459,7 @@
|
||||
step="0.001"
|
||||
min="0"
|
||||
class="input"
|
||||
placeholder="0.134"
|
||||
placeholder="0.201"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
@@ -1139,7 +1139,7 @@
|
||||
step="0.001"
|
||||
min="0"
|
||||
class="input"
|
||||
placeholder="0.134"
|
||||
placeholder="0.201"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
|
||||
@@ -17,8 +17,43 @@
|
||||
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
|
||||
</div>
|
||||
</div>
|
||||
<UsageFilters v-model="filters" v-model:startDate="startDate" v-model:endDate="endDate" :exporting="exporting" @change="applyFilters" @refresh="refreshData" @reset="resetFilters" @cleanup="openCleanupDialog" @export="exportToExcel" />
|
||||
<UsageTable :data="usageLogs" :loading="loading" />
|
||||
<UsageFilters v-model="filters" v-model:startDate="startDate" v-model:endDate="endDate" :exporting="exporting" @change="applyFilters" @refresh="refreshData" @reset="resetFilters" @cleanup="openCleanupDialog" @export="exportToExcel">
|
||||
<template #after-reset>
|
||||
<div class="relative" ref="columnDropdownRef">
|
||||
<button
|
||||
@click="showColumnDropdown = !showColumnDropdown"
|
||||
class="btn btn-secondary px-2 md:px-3"
|
||||
:title="t('admin.users.columnSettings')"
|
||||
>
|
||||
<svg class="h-4 w-4 md:mr-1.5" fill="none" stroke="currentColor" viewBox="0 0 24 24" stroke-width="1.5">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M9 4.5v15m6-15v15m-10.875 0h15.75c.621 0 1.125-.504 1.125-1.125V5.625c0-.621-.504-1.125-1.125-1.125H4.125C3.504 4.5 3 5.004 3 5.625v12.75c0 .621.504 1.125 1.125 1.125z" />
|
||||
</svg>
|
||||
<span class="hidden md:inline">{{ t('admin.users.columnSettings') }}</span>
|
||||
</button>
|
||||
<div
|
||||
v-if="showColumnDropdown"
|
||||
class="absolute right-0 top-full z-50 mt-1 max-h-80 w-48 overflow-y-auto rounded-lg border border-gray-200 bg-white py-1 shadow-lg dark:border-dark-600 dark:bg-dark-800"
|
||||
>
|
||||
<button
|
||||
v-for="col in toggleableColumns"
|
||||
:key="col.key"
|
||||
@click="toggleColumn(col.key)"
|
||||
class="flex w-full items-center justify-between px-4 py-2 text-left text-sm text-gray-700 hover:bg-gray-100 dark:text-gray-300 dark:hover:bg-dark-700"
|
||||
>
|
||||
<span>{{ col.label }}</span>
|
||||
<Icon
|
||||
v-if="isColumnVisible(col.key)"
|
||||
name="check"
|
||||
size="sm"
|
||||
class="text-primary-500"
|
||||
:stroke-width="2"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</UsageFilters>
|
||||
<UsageTable :data="usageLogs" :loading="loading" :columns="visibleColumns" />
|
||||
<Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" />
|
||||
</div>
|
||||
</AppLayout>
|
||||
@@ -43,6 +78,7 @@ import UsageStatsCards from '@/components/admin/usage/UsageStatsCards.vue'; impo
|
||||
import UsageTable from '@/components/admin/usage/UsageTable.vue'; import UsageExportProgress from '@/components/admin/usage/UsageExportProgress.vue'
|
||||
import UsageCleanupDialog from '@/components/admin/usage/UsageCleanupDialog.vue'
|
||||
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'; import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import type { AdminUsageLog, TrendDataPoint, ModelStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage'
|
||||
|
||||
const { t } = useI18n()
|
||||
@@ -141,6 +177,77 @@ const exportToExcel = async () => {
|
||||
finally { if(exportAbortController === c) { exportAbortController = null; exporting.value = false; exportProgress.show = false } }
|
||||
}
|
||||
|
||||
onMounted(() => { loadLogs(); loadStats(); loadChartData() })
|
||||
onUnmounted(() => { abortController?.abort(); exportAbortController?.abort() })
|
||||
// Column visibility
|
||||
const ALWAYS_VISIBLE = ['user', 'created_at']
|
||||
const DEFAULT_HIDDEN_COLUMNS = ['reasoning_effort', 'user_agent']
|
||||
const HIDDEN_COLUMNS_KEY = 'usage-hidden-columns'
|
||||
|
||||
const allColumns = computed(() => [
|
||||
{ key: 'user', label: t('admin.usage.user'), sortable: false },
|
||||
{ key: 'api_key', label: t('usage.apiKeyFilter'), sortable: false },
|
||||
{ key: 'account', label: t('admin.usage.account'), sortable: false },
|
||||
{ key: 'model', label: t('usage.model'), sortable: true },
|
||||
{ key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false },
|
||||
{ key: 'group', label: t('admin.usage.group'), sortable: false },
|
||||
{ key: 'stream', label: t('usage.type'), sortable: false },
|
||||
{ key: 'tokens', label: t('usage.tokens'), sortable: false },
|
||||
{ key: 'cost', label: t('usage.cost'), sortable: false },
|
||||
{ key: 'first_token', label: t('usage.firstToken'), sortable: false },
|
||||
{ key: 'duration', label: t('usage.duration'), sortable: false },
|
||||
{ key: 'created_at', label: t('usage.time'), sortable: true },
|
||||
{ key: 'user_agent', label: t('usage.userAgent'), sortable: false },
|
||||
{ key: 'ip_address', label: t('admin.usage.ipAddress'), sortable: false }
|
||||
])
|
||||
|
||||
const hiddenColumns = reactive<Set<string>>(new Set())
|
||||
|
||||
const toggleableColumns = computed(() =>
|
||||
allColumns.value.filter(col => !ALWAYS_VISIBLE.includes(col.key))
|
||||
)
|
||||
|
||||
const visibleColumns = computed(() =>
|
||||
allColumns.value.filter(col =>
|
||||
ALWAYS_VISIBLE.includes(col.key) || !hiddenColumns.has(col.key)
|
||||
)
|
||||
)
|
||||
|
||||
const isColumnVisible = (key: string) => !hiddenColumns.has(key)
|
||||
|
||||
const toggleColumn = (key: string) => {
|
||||
if (hiddenColumns.has(key)) {
|
||||
hiddenColumns.delete(key)
|
||||
} else {
|
||||
hiddenColumns.add(key)
|
||||
}
|
||||
try {
|
||||
localStorage.setItem(HIDDEN_COLUMNS_KEY, JSON.stringify([...hiddenColumns]))
|
||||
} catch (e) {
|
||||
console.error('Failed to save columns:', e)
|
||||
}
|
||||
}
|
||||
|
||||
const loadSavedColumns = () => {
|
||||
try {
|
||||
const saved = localStorage.getItem(HIDDEN_COLUMNS_KEY)
|
||||
if (saved) {
|
||||
(JSON.parse(saved) as string[]).forEach(key => hiddenColumns.add(key))
|
||||
} else {
|
||||
DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
|
||||
}
|
||||
} catch {
|
||||
DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
|
||||
}
|
||||
}
|
||||
|
||||
const showColumnDropdown = ref(false)
|
||||
const columnDropdownRef = ref<HTMLElement | null>(null)
|
||||
|
||||
const handleColumnClickOutside = (event: MouseEvent) => {
|
||||
if (columnDropdownRef.value && !columnDropdownRef.value.contains(event.target as HTMLElement)) {
|
||||
showColumnDropdown.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => { loadLogs(); loadStats(); loadChartData(); loadSavedColumns(); document.addEventListener('click', handleColumnClickOutside) })
|
||||
onUnmounted(() => { abortController?.abort(); exportAbortController?.abort(); document.removeEventListener('click', handleColumnClickOutside) })
|
||||
</script>
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
import { defineConfig } from 'vitest/config'
|
||||
import vue from '@vitejs/plugin-vue'
|
||||
import { resolve } from 'path'
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [vue()],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': resolve(__dirname, 'src'),
|
||||
'vue-i18n': 'vue-i18n/dist/vue-i18n.runtime.esm-bundler.js'
|
||||
}
|
||||
},
|
||||
define: {
|
||||
__INTLIFY_JIT_COMPILATION__: true
|
||||
},
|
||||
test: {
|
||||
globals: true,
|
||||
environment: 'jsdom',
|
||||
@@ -37,8 +32,6 @@ export default defineConfig({
|
||||
lines: 80
|
||||
}
|
||||
}
|
||||
},
|
||||
setupFiles: ['./src/__tests__/setup.ts'],
|
||||
testTimeout: 10000
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user