mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 07:22:13 +08:00
Compare commits
322 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
04b4d7de3a | ||
|
|
b11cfdc11f | ||
|
|
cdddbec629 | ||
|
|
b3fe0506fb | ||
|
|
8ca0e2772e | ||
|
|
826090e099 | ||
|
|
7399de6ecc | ||
|
|
25cb5e7505 | ||
|
|
5c13ec3121 | ||
|
|
d8aff3a7e3 | ||
|
|
f44927b9f8 | ||
|
|
c0110cb5af | ||
|
|
1f8e1142a0 | ||
|
|
1e51de88d6 | ||
|
|
30995b5397 | ||
|
|
eb60f67054 | ||
|
|
78193ceec1 | ||
|
|
f0e08e7687 | ||
|
|
10b8259259 | ||
|
|
eb0b77bf4d | ||
|
|
9d81467937 | ||
|
|
fd8ccaf01a | ||
|
|
2b30e3b6d7 | ||
|
|
6e90ec6111 | ||
|
|
8dd38f4775 | ||
|
|
fbd73f248f | ||
|
|
3fcefe6c32 | ||
|
|
f740d2c291 | ||
|
|
bf6585a40f | ||
|
|
8c2dd7b3f0 | ||
|
|
4167c437a8 | ||
|
|
0ddaef3c9a | ||
|
|
2fc6aaf936 | ||
|
|
1c0519f1c7 | ||
|
|
6bbe7800be | ||
|
|
2694149489 | ||
|
|
a17ac50118 | ||
|
|
656a77d585 | ||
|
|
7455476c60 | ||
|
|
5f6e929d61 | ||
|
|
4caa3c2701 | ||
|
|
36cda57c81 | ||
|
|
9f1f203b84 | ||
|
|
b41a8ca93f | ||
|
|
e3cf0c0e10 | ||
|
|
de18bce9aa | ||
|
|
3cc407bc0e | ||
|
|
00a0a12138 | ||
|
|
b08767a4f9 | ||
|
|
ac6bde7a98 | ||
|
|
d2d41d68dd | ||
|
|
944b7f7617 | ||
|
|
53825eb073 | ||
|
|
1a7f49513f | ||
|
|
885a2ce7ef | ||
|
|
14ba80a0af | ||
|
|
5fa22fdf82 | ||
|
|
bcaae2eb91 | ||
|
|
767a41e263 | ||
|
|
252d6c5301 | ||
|
|
7a4e65ad4b | ||
|
|
a582aa89a9 | ||
|
|
acefa1da12 | ||
|
|
a88698f3fc | ||
|
|
ebc6755b33 | ||
|
|
c8eff34388 | ||
|
|
f19b03825b | ||
|
|
25178cdbe1 | ||
|
|
a461538d58 | ||
|
|
b43ee62947 | ||
|
|
ebe6f418f3 | ||
|
|
391e79f8ee | ||
|
|
c7fcb7a84b | ||
|
|
87f4ed591e | ||
|
|
440d2e28ed | ||
|
|
6cb8980404 | ||
|
|
fe752bbd35 | ||
|
|
c74d451fa2 | ||
|
|
12d743fb35 | ||
|
|
6acb9f7910 | ||
|
|
eb6f5c6927 | ||
|
|
7ccb4c8ea3 | ||
|
|
4ce986d47d | ||
|
|
91ef085d7d | ||
|
|
97aaa24733 | ||
|
|
faf6441633 | ||
|
|
00c151b463 | ||
|
|
106b20cdbf | ||
|
|
c069b3b1e8 | ||
|
|
a2ae9f1f27 | ||
|
|
4cd6d86426 | ||
|
|
fa72f1947a | ||
|
|
9ee7d3935d | ||
|
|
1071fe0ac7 | ||
|
|
0be003377f | ||
|
|
ca3f497b56 | ||
|
|
034b84b707 | ||
|
|
1624523c4e | ||
|
|
313afe14ce | ||
|
|
01180b316f | ||
|
|
ee7d061001 | ||
|
|
60c5949a74 | ||
|
|
2ebbd4c94d | ||
|
|
785115c62b | ||
|
|
e643fc382c | ||
|
|
34aad82ac3 | ||
|
|
0c29468f90 | ||
|
|
9301dae63e | ||
|
|
2475d4a205 | ||
|
|
be75fc3474 | ||
|
|
785e049af3 | ||
|
|
be4e49e6d7 | ||
|
|
1307d604e7 | ||
|
|
45d57018eb | ||
|
|
03bf348530 | ||
|
|
cab60ef735 | ||
|
|
a3791104f9 | ||
|
|
2b3e40bb2a | ||
|
|
0c1dcad429 | ||
|
|
101ef0cf62 | ||
|
|
0debe0a80c | ||
|
|
d22e62ac8a | ||
|
|
1ee17383f8 | ||
|
|
b59c79c458 | ||
|
|
c28f691f32 | ||
|
|
a9ecd7bcc6 | ||
|
|
f89465fb39 | ||
|
|
440c3f46a7 | ||
|
|
c746964936 | ||
|
|
b2d6879b3f | ||
|
|
ce4095904e | ||
|
|
0d2dcbff11 | ||
|
|
02dca78dbe | ||
|
|
8df299b767 | ||
|
|
95cf59b2f6 | ||
|
|
a0f643da0e | ||
|
|
99331a5285 | ||
|
|
9245e197a8 | ||
|
|
5fa1d0978f | ||
|
|
0bb683a6f1 | ||
|
|
cb3958bac3 | ||
|
|
91ebf95efa | ||
|
|
7895dd65b2 | ||
|
|
09f8894906 | ||
|
|
45fdef9da7 | ||
|
|
704f256bd7 | ||
|
|
a6026e7ac4 | ||
|
|
1e03b2974a | ||
|
|
daa7c783b9 | ||
|
|
8a82a2a648 | ||
|
|
c3ac68af2a | ||
|
|
ec3897b981 | ||
|
|
62486cee37 | ||
|
|
7c5746ffbc | ||
|
|
47f7b0213b | ||
|
|
8df42f7aab | ||
|
|
d666e05a6d | ||
|
|
c37edf2de5 | ||
|
|
bb9af2465e | ||
|
|
7af00864b3 | ||
|
|
81903e87e3 | ||
|
|
0b96c7a65e | ||
|
|
34ccfe45ea | ||
|
|
d4231150a9 | ||
|
|
2268e93aec | ||
|
|
c976916b6d | ||
|
|
0bbd003a18 | ||
|
|
336a844712 | ||
|
|
51e7f262bd | ||
|
|
99663a3f20 | ||
|
|
4d88248091 | ||
|
|
c143367e15 | ||
|
|
ca44389baa | ||
|
|
05c2a65ef0 | ||
|
|
3f21d204a7 | ||
|
|
c848f950ce | ||
|
|
87bd765a57 | ||
|
|
49325a769e | ||
|
|
238d86f502 | ||
|
|
7064063230 | ||
|
|
b1de4352a8 | ||
|
|
7bf5c1cbcb | ||
|
|
411e24146d | ||
|
|
c7392fc80b | ||
|
|
c37c68a341 | ||
|
|
9230d3cbc9 | ||
|
|
19925e22d9 | ||
|
|
65e0c1b258 | ||
|
|
94bdde32bb | ||
|
|
14c80d26c6 | ||
|
|
9555a99d1c | ||
|
|
0e69895603 | ||
|
|
cc3cf1d70a | ||
|
|
3382d496e3 | ||
|
|
e0b4b00dc1 | ||
|
|
81d896bf78 | ||
|
|
ec576fdbde | ||
|
|
741eae59bb | ||
|
|
505494b378 | ||
|
|
c1033c12bd | ||
|
|
b789333b68 | ||
|
|
61ef73cb12 | ||
|
|
e71be7e0f1 | ||
|
|
0302c03864 | ||
|
|
d21fe54d55 | ||
|
|
a70d3ff82d | ||
|
|
574359f1df | ||
|
|
6da2f54e50 | ||
|
|
886464b2e9 | ||
|
|
396044e354 | ||
|
|
3d15202124 | ||
|
|
756b09b6b8 | ||
|
|
f4d3fadd6f | ||
|
|
1fb6e9e830 | ||
|
|
78ac6a7a29 | ||
|
|
efe8810dff | ||
|
|
5c07e11473 | ||
|
|
d552ad7673 | ||
|
|
496173da1f | ||
|
|
1cdaf33272 | ||
|
|
e2b3969492 | ||
|
|
8661bf8837 | ||
|
|
292fa7a6d2 | ||
|
|
39d7300a8e | ||
|
|
0ad20c9489 | ||
|
|
7eb3b23ddf | ||
|
|
5b06542193 | ||
|
|
38dca4f787 | ||
|
|
737d1ecf5b | ||
|
|
f819cef6d5 | ||
|
|
facae2a6db | ||
|
|
8b021c099d | ||
|
|
8a625188ce | ||
|
|
c0cfa6acde | ||
|
|
0913cfc082 | ||
|
|
59e8465325 | ||
|
|
be6e8ff77b | ||
|
|
ebb85cf843 | ||
|
|
ef959bc3c6 | ||
|
|
8cb7356bbc | ||
|
|
496545188a | ||
|
|
739c80227a | ||
|
|
6218eefd61 | ||
|
|
5715587baf | ||
|
|
ae770a625b | ||
|
|
428ee065d3 | ||
|
|
320ca28f90 | ||
|
|
5e518f5fbd | ||
|
|
24dcba1d72 | ||
|
|
1abc688cad | ||
|
|
34936189d8 | ||
|
|
daf7bf3e8b | ||
|
|
3a9f1c5796 | ||
|
|
bb1e205516 | ||
|
|
9af4a55176 | ||
|
|
51e903c34e | ||
|
|
7f03319646 | ||
|
|
0c33d18a4d | ||
|
|
a747c63b8e | ||
|
|
a03c361b04 | ||
|
|
807d0018ef | ||
|
|
850e267763 | ||
|
|
c75ae56f10 | ||
|
|
caaed775aa | ||
|
|
d11b295729 | ||
|
|
91d0059f8d | ||
|
|
44693d0dfb | ||
|
|
c722212e12 | ||
|
|
f12da65962 | ||
|
|
92745f7534 | ||
|
|
f2917aeaf8 | ||
|
|
a1e2ffd586 | ||
|
|
78a9705fad | ||
|
|
5b6da04a02 | ||
|
|
4661c2f90f | ||
|
|
90e4328885 | ||
|
|
130112a84a | ||
|
|
b368bb6ea1 | ||
|
|
9ecb6211d6 | ||
|
|
79fba9c8d3 | ||
|
|
37c76a93ab | ||
|
|
86e600aa52 | ||
|
|
6a1c28b70e | ||
|
|
9375f1809c | ||
|
|
f176150c93 | ||
|
|
30d25084f0 | ||
|
|
91ad94d941 | ||
|
|
57a778dccf | ||
|
|
f702c66659 | ||
|
|
a095468850 | ||
|
|
f9b6a20995 | ||
|
|
f2770da880 | ||
|
|
d269659e61 | ||
|
|
c4d6715443 | ||
|
|
fc4a1c5433 | ||
|
|
6bdd580b3f | ||
|
|
9cf4882f4c | ||
|
|
406dad998d | ||
|
|
8b0db22c18 | ||
|
|
f06048eccf | ||
|
|
05f5a8b61d | ||
|
|
662625a091 | ||
|
|
6328e69441 | ||
|
|
425dfb80d9 | ||
|
|
4c1fd570f0 | ||
|
|
345f853b5d | ||
|
|
100a70f87c | ||
|
|
18b591bc3b | ||
|
|
6a52b24369 | ||
|
|
228aca9523 | ||
|
|
7e4637cd70 | ||
|
|
3e3c015efa | ||
|
|
30c30b1712 | ||
|
|
e666356483 | ||
|
|
3710bc883b | ||
|
|
64f60d15b0 | ||
|
|
bc4a044337 | ||
|
|
cb233bfa66 | ||
|
|
d46059a735 | ||
|
|
da2fbd9924 | ||
|
|
084e0adb34 | ||
|
|
3bddbb6afe |
4
.github/workflows/backend-ci.yml
vendored
4
.github/workflows/backend-ci.yml
vendored
@@ -17,6 +17,7 @@ jobs:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache: true
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.26.1'
|
||||
@@ -36,12 +37,13 @@ jobs:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache: true
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.26.1'
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.11
|
||||
version: v2.9
|
||||
args: --timeout=30m
|
||||
working-directory: backend
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -78,6 +78,7 @@ Desktop.ini
|
||||
# ===================
|
||||
tmp/
|
||||
temp/
|
||||
logs/
|
||||
*.tmp
|
||||
*.temp
|
||||
*.log
|
||||
@@ -128,8 +129,15 @@ deploy/docker-compose.override.yml
|
||||
vite.config.js
|
||||
docs/*
|
||||
.serena/
|
||||
|
||||
# ===================
|
||||
# 压测工具
|
||||
# ===================
|
||||
tools/loadtest/
|
||||
# Antigravity Manager
|
||||
Antigravity-Manager/
|
||||
antigravity_projectid_fix.patch
|
||||
.codex/
|
||||
frontend/coverage/
|
||||
aicodex
|
||||
output/
|
||||
|
||||
|
||||
28
README.md
28
README.md
@@ -150,14 +150,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# Start services
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# View logs
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**What the script does:**
|
||||
- Downloads `docker-compose.local.yml` and `.env.example`
|
||||
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
|
||||
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
||||
- Creates `.env` file with auto-generated secrets
|
||||
- Creates data directories (uses local directories for easy backup/migration)
|
||||
@@ -522,6 +522,28 @@ sub2api/
|
||||
└── install.sh # One-click installation script
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
> **Please read carefully before using this project:**
|
||||
>
|
||||
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
|
||||
>
|
||||
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
28
README_CN.md
28
README_CN.md
@@ -154,14 +154,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# 启动服务
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# 查看日志
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**脚本功能:**
|
||||
- 下载 `docker-compose.local.yml` 和 `.env.example`
|
||||
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example`
|
||||
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
||||
- 创建 `.env` 文件并填充自动生成的密钥
|
||||
- 创建数据目录(使用本地目录,便于备份和迁移)
|
||||
@@ -588,6 +588,28 @@ sub2api/
|
||||
└── install.sh # 一键安装脚本
|
||||
```
|
||||
|
||||
## 免责声明
|
||||
|
||||
> **使用本项目前请仔细阅读:**
|
||||
>
|
||||
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
|
||||
>
|
||||
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.1.88
|
||||
0.1.96.1
|
||||
|
||||
@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
@@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
@@ -162,7 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
@@ -229,7 +229,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
|
||||
@@ -62,26 +62,28 @@ type Group struct {
|
||||
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||
// 是否仅允许 Claude Code 客户端
|
||||
// allow Claude Code client only
|
||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||
// 非 Claude Code 请求降级使用的分组 ID
|
||||
// fallback group for non-Claude-Code requests
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
// 无效请求兜底使用的分组 ID
|
||||
// fallback group for invalid request
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||
// model routing config: pattern -> account ids
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
// 是否启用模型路由配置
|
||||
// whether model routing is enabled
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
|
||||
// whether MCP XML prompt injection is enabled
|
||||
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
||||
// 支持的模型系列:claude, gemini_text, gemini_image
|
||||
// supported model scopes: claude, gemini_text, gemini_image
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
// 分组显示排序,数值越小越靠前
|
||||
// group display order, lower comes first
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// simulate claude usage as claude-max style (1h cache write)
|
||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -190,7 +192,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldSimulateClaudeMaxEnabled:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
@@ -431,6 +433,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.DefaultMappedModel = value.String
|
||||
}
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SimulateClaudeMaxEnabled = value.Bool
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -630,6 +638,9 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("default_mapped_model=")
|
||||
builder.WriteString(_m.DefaultMappedModel)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("simulate_claude_max_enabled=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -79,6 +79,8 @@ const (
|
||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database.
|
||||
FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -186,6 +188,7 @@ var Columns = []string{
|
||||
FieldSortOrder,
|
||||
FieldAllowMessagesDispatch,
|
||||
FieldDefaultMappedModel,
|
||||
FieldSimulateClaudeMaxEnabled,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -259,6 +262,8 @@ var (
|
||||
DefaultDefaultMappedModel string
|
||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
DefaultMappedModelValidator func(string) error
|
||||
// DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field.
|
||||
DefaultSimulateClaudeMaxEnabled bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -419,6 +424,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySimulateClaudeMaxEnabled orders the results by the simulate_claude_max_enabled field.
|
||||
func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -205,6 +205,11 @@ func DefaultMappedModel(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// SimulateClaudeMaxEnabled applies equality check predicate on the "simulate_claude_max_enabled" field. It's identical to SimulateClaudeMaxEnabledEQ.
|
||||
func SimulateClaudeMaxEnabled(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1555,6 +1560,16 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// SimulateClaudeMaxEnabledEQ applies the EQ predicate on the "simulate_claude_max_enabled" field.
|
||||
func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||
}
|
||||
|
||||
// SimulateClaudeMaxEnabledNEQ applies the NEQ predicate on the "simulate_claude_max_enabled" field.
|
||||
func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -452,6 +452,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate {
|
||||
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetSimulateClaudeMaxEnabled(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -649,6 +663,10 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultDefaultMappedModel
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
}
|
||||
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
||||
v := group.DefaultSimulateClaudeMaxEnabled
|
||||
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -730,6 +748,9 @@ func (_c *GroupCreate) check() error {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
||||
return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -885,6 +906,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
_node.DefaultMappedModel = value
|
||||
}
|
||||
if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||
_node.SimulateClaudeMaxEnabled = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1599,6 +1624,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldSimulateClaudeMaxEnabled, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldSimulateClaudeMaxEnabled)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -2295,6 +2332,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSimulateClaudeMaxEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSimulateClaudeMaxEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -3157,6 +3208,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSimulateClaudeMaxEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSimulateClaudeMaxEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -653,6 +653,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate {
|
||||
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetSimulateClaudeMaxEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1149,6 +1163,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -2081,6 +2098,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSimulateClaudeMaxEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2607,6 +2638,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -410,6 +410,7 @@ var (
|
||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
{Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
|
||||
@@ -8252,6 +8252,7 @@ type GroupMutation struct {
|
||||
addsort_order *int
|
||||
allow_messages_dispatch *bool
|
||||
default_mapped_model *string
|
||||
simulate_claude_max_enabled *bool
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -10068,6 +10069,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
|
||||
m.default_mapped_model = nil
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) {
|
||||
m.simulate_claude_max_enabled = &b
|
||||
}
|
||||
|
||||
// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation.
|
||||
func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) {
|
||||
v := m.simulate_claude_max_enabled
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", err)
|
||||
}
|
||||
return oldValue.SimulateClaudeMaxEnabled, nil
|
||||
}
|
||||
|
||||
// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field.
|
||||
func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() {
|
||||
m.simulate_claude_max_enabled = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -10426,7 +10463,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 32)
|
||||
fields := make([]string, 0, 33)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -10523,6 +10560,9 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.default_mapped_model != nil {
|
||||
fields = append(fields, group.FieldDefaultMappedModel)
|
||||
}
|
||||
if m.simulate_claude_max_enabled != nil {
|
||||
fields = append(fields, group.FieldSimulateClaudeMaxEnabled)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -10595,6 +10635,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.AllowMessagesDispatch()
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.DefaultMappedModel()
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
return m.SimulateClaudeMaxEnabled()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -10668,6 +10710,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldAllowMessagesDispatch(ctx)
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.OldDefaultMappedModel(ctx)
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
return m.OldSimulateClaudeMaxEnabled(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -10901,6 +10945,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetDefaultMappedModel(v)
|
||||
return nil
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetSimulateClaudeMaxEnabled(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -11334,6 +11385,9 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldDefaultMappedModel:
|
||||
m.ResetDefaultMappedModel()
|
||||
return nil
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
m.ResetSimulateClaudeMaxEnabled()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
|
||||
@@ -463,6 +463,10 @@ func init() {
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||
// groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field.
|
||||
groupDescSimulateClaudeMaxEnabled := groupFields[29].Descriptor()
|
||||
// group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field.
|
||||
group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||
_ = idempotencyrecordMixinFields0
|
||||
|
||||
@@ -33,8 +33,6 @@ func (Group) Mixin() []ent.Mixin {
|
||||
|
||||
func (Group) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用
|
||||
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
|
||||
field.String("name").
|
||||
MaxLen(100).
|
||||
NotEmpty(),
|
||||
@@ -51,7 +49,6 @@ func (Group) Fields() []ent.Field {
|
||||
MaxLen(20).
|
||||
Default(domain.StatusActive),
|
||||
|
||||
// Subscription-related fields (added by migration 003)
|
||||
field.String("platform").
|
||||
MaxLen(50).
|
||||
Default(domain.PlatformAnthropic),
|
||||
@@ -73,7 +70,6 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int("default_validity_days").
|
||||
Default(30),
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
field.Float("image_price_1k").
|
||||
Optional().
|
||||
Nillable().
|
||||
@@ -87,7 +83,6 @@ func (Group) Fields() []ent.Field {
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
|
||||
// Sora 按次计费配置(阶段 1)
|
||||
field.Float("sora_image_price_360").
|
||||
Optional().
|
||||
Nillable().
|
||||
@@ -109,45 +104,38 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int64("sora_storage_quota_bytes").
|
||||
Default(0),
|
||||
|
||||
// Claude Code 客户端限制 (added by migration 029)
|
||||
field.Bool("claude_code_only").
|
||||
Default(false).
|
||||
Comment("是否仅允许 Claude Code 客户端"),
|
||||
Comment("allow Claude Code client only"),
|
||||
field.Int64("fallback_group_id").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||
Comment("fallback group for non-Claude-Code requests"),
|
||||
field.Int64("fallback_group_id_on_invalid_request").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("无效请求兜底使用的分组 ID"),
|
||||
Comment("fallback group for invalid request"),
|
||||
|
||||
// 模型路由配置 (added by migration 040)
|
||||
field.JSON("model_routing", map[string][]int64{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
|
||||
|
||||
// 模型路由开关 (added by migration 041)
|
||||
Comment("model routing config: pattern -> account ids"),
|
||||
field.Bool("model_routing_enabled").
|
||||
Default(false).
|
||||
Comment("是否启用模型路由配置"),
|
||||
Comment("whether model routing is enabled"),
|
||||
|
||||
// MCP XML 协议注入开关 (added by migration 042)
|
||||
field.Bool("mcp_xml_inject").
|
||||
Default(true).
|
||||
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
|
||||
Comment("whether MCP XML prompt injection is enabled"),
|
||||
|
||||
// 支持的模型系列 (added by migration 046)
|
||||
field.JSON("supported_model_scopes", []string{}).
|
||||
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
|
||||
Comment("supported model scopes: claude, gemini_text, gemini_image"),
|
||||
|
||||
// 分组排序 (added by migration 052)
|
||||
field.Int("sort_order").
|
||||
Default(0).
|
||||
Comment("分组显示排序,数值越小越靠前"),
|
||||
Comment("group display order, lower comes first"),
|
||||
|
||||
// OpenAI Messages 调度配置 (added by migration 069)
|
||||
field.Bool("allow_messages_dispatch").
|
||||
@@ -157,6 +145,9 @@ func (Group) Fields() []ent.Field {
|
||||
MaxLen(100).
|
||||
Default("").
|
||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||
field.Bool("simulate_claude_max_enabled").
|
||||
Default(false).
|
||||
Comment("simulate claude usage as claude-max style (1h cache write)"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,14 +163,11 @@ func (Group) Edges() []ent.Edge {
|
||||
edge.From("allowed_users", User.Type).
|
||||
Ref("allowed_groups").
|
||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
||||
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
||||
}
|
||||
}
|
||||
|
||||
func (Group) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
|
||||
index.Fields("status"),
|
||||
index.Fields("platform"),
|
||||
index.Fields("subscription_type"),
|
||||
|
||||
@@ -87,6 +87,7 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
@@ -137,6 +138,8 @@ require (
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
|
||||
@@ -124,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
@@ -180,6 +182,7 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -199,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -281,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -333,6 +342,8 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@@ -341,8 +352,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
||||
@@ -432,11 +441,11 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
|
||||
@@ -1402,7 +1402,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
||||
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||
|
||||
@@ -84,10 +84,12 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
// Gemini 2.5 白名单
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
// Gemini 3 白名单
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
|
||||
@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]string{
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
}
|
||||
}
|
||||
|
||||
enrichCredentialsFromIDToken(&item)
|
||||
|
||||
accountInput := &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||
// Existing credential values are never overwritten — only missing fields are filled.
|
||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||
if item.Credentials == nil {
|
||||
return
|
||||
}
|
||||
// Only enrich OpenAI/Sora OAuth accounts
|
||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, _ := item.Credentials["id_token"].(string)
|
||||
if strings.TrimSpace(idToken) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeIDToken skips expiry validation — safe for imported data
|
||||
claims, err := openai.DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := claims.GetUserInfo()
|
||||
if userInfo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fill missing fields only (never overwrite existing values)
|
||||
setIfMissing := func(key, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
if existing, _ := item.Credentials[key].(string); existing == "" {
|
||||
item.Credentials[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
setIfMissing("email", userInfo.Email)
|
||||
setIfMissing("plan_type", userInfo.PlanType)
|
||||
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
|
||||
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
|
||||
setIfMissing("organization_id", userInfo.OrganizationID)
|
||||
}
|
||||
|
||||
func normalizeProxyStatus(status string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||
switch normalized {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -626,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
// TestAccountRequest represents the request body for testing an account
|
||||
type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
@@ -656,10 +659,46 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService != nil {
|
||||
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverState handles unified recovery of recoverable account runtime state.
|
||||
// POST /api/v1/admin/accounts/:id/recover-state
|
||||
func (h *AccountHandler) RecoverState(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
if h.rateLimitService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
|
||||
InvalidateToken: true,
|
||||
}); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
|
||||
@@ -715,52 +754,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
||||
// refreshSingleAccount refreshes credentials for a single OAuth account.
|
||||
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
|
||||
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
|
||||
}
|
||||
|
||||
var newCredentials map[string]any
|
||||
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformGemini {
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
|
||||
}
|
||||
|
||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -770,10 +788,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformAntigravity {
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -792,37 +809,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
return updatedAccount, "missing_project_id_temporary", nil
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
|
||||
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
@@ -844,20 +851,51 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
if warning == "missing_project_id_temporary" {
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
@@ -913,14 +951,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// BatchClearError handles batch clearing account errors
|
||||
// POST /api/v1/admin/accounts/batch-clear-error
|
||||
func (h *AccountHandler) BatchClearError(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, id := range req.AccountIDs {
|
||||
accountID := id // 闭包捕获
|
||||
g.Go(func() error {
|
||||
account, err := h.adminService.ClearAccountError(gctx, accountID)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": accountID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
successCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchRefresh handles batch refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/batch-refresh
|
||||
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 建立已获取账号的 ID 集合,检测缺失的 ID
|
||||
foundIDs := make(map[int64]bool, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc != nil {
|
||||
foundIDs[acc.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
var warnings []gin.H
|
||||
|
||||
// 将不存在的账号 ID 标记为失败
|
||||
for _, id := range req.AccountIDs {
|
||||
if !foundIDs[id] {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": id,
|
||||
"error": "account not found",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, account := range accounts {
|
||||
acc := account // 闭包捕获
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
g.Go(func() error {
|
||||
_, warning, err := h.refreshSingleAccount(gctx, acc)
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
successCount++
|
||||
if warning != "" {
|
||||
warnings = append(warnings, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"warning": warning,
|
||||
})
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
// POST /api/v1/admin/accounts/batch
|
||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
@@ -1139,6 +1338,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
|
||||
@@ -175,6 +175,10 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
@@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
@@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"groups": stats,
|
||||
@@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
|
||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCacheProbe struct {
|
||||
service.UsageLogRepository
|
||||
trendCalls atomic.Int32
|
||||
usersTrendCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
r.trendCalls.Add(1)
|
||||
return []usagestats.TrendDataPoint{{
|
||||
Date: "2026-03-11",
|
||||
Requests: 1,
|
||||
TotalTokens: 2,
|
||||
Cost: 3,
|
||||
ActualCost: 4,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
limit int,
|
||||
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
r.usersTrendCalls.Add(1)
|
||||
return []usagestats.UserUsageTrendPoint{{
|
||||
Date: "2026-03-11",
|
||||
UserID: 1,
|
||||
Email: "cache@test.dev",
|
||||
Requests: 2,
|
||||
Tokens: 20,
|
||||
Cost: 2,
|
||||
ActualCost: 1,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func resetDashboardReadCachesForTest() {
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||
}
|
||||
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
var (
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
)
|
||||
|
||||
type dashboardTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Model string `json:"model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardModelGroupCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardEntityTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
func cacheStatusValue(hit bool) string {
|
||||
if hit {
|
||||
return "hit"
|
||||
}
|
||||
return "miss"
|
||||
}
|
||||
|
||||
func mustMarshalDashboardCacheKey(value any) string {
|
||||
raw, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func snapshotPayloadAs[T any](payload any) (T, error) {
|
||||
typed, ok := payload.(T)
|
||||
if !ok {
|
||||
var zero T
|
||||
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
|
||||
}
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUsageTrendCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getModelStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getGroupStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.GroupStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||
return h.buildSnapshotV2Response(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters,
|
||||
includeStats,
|
||||
includeTrend,
|
||||
includeModels,
|
||||
includeGroups,
|
||||
includeUsersTrend,
|
||||
usersTrendLimit,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
response.Error(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
response.Success(c, cached.Payload)
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
filters *dashboardSnapshotV2Filters,
|
||||
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||
usersTrendLimit int,
|
||||
) (*dashboardSnapshotV2Response, error) {
|
||||
resp := &dashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
StartDate: startTime.Format("2006-01-02"),
|
||||
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeStats {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get dashboard statistics")
|
||||
}
|
||||
resp.Stats = &dashboardSnapshotV2Stats{
|
||||
DashboardStats: *stats,
|
||||
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeTrend {
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
||||
c.Request.Context(),
|
||||
trend, _, err := h.getUsageTrendCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
@@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get usage trend")
|
||||
}
|
||||
resp.Trend = trend
|
||||
}
|
||||
|
||||
if includeModels {
|
||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
models, _, err := h.getModelStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get model statistics")
|
||||
}
|
||||
resp.Models = models
|
||||
}
|
||||
|
||||
if includeGroups {
|
||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
groups, _, err := h.getGroupStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get group statistics")
|
||||
}
|
||||
resp.Groups = groups
|
||||
}
|
||||
|
||||
if includeUsersTrend {
|
||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
usersTrendLimit,
|
||||
)
|
||||
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get user usage trend")
|
||||
}
|
||||
resp.UsersTrend = usersTrend
|
||||
}
|
||||
|
||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||
|
||||
@@ -46,9 +46,10 @@ type CreateGroupRequest struct {
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
@@ -84,9 +85,10 @@ type UpdateGroupRequest struct {
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
@@ -207,6 +209,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
@@ -260,6 +263,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
@@ -335,6 +339,27 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if entries == nil {
|
||||
entries = []service.UserGroupRateEntry{}
|
||||
}
|
||||
response.Success(c, entries)
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent",
|
||||
"concurrency_queue_depth",
|
||||
"group_available_accounts",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_rate_limited_count",
|
||||
"account_error_count",
|
||||
"account_error_ratio",
|
||||
"overload_account_count",
|
||||
}
|
||||
|
||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
||||
"error_rate",
|
||||
"upstream_error_rate",
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent":
|
||||
"memory_usage_percent",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_error_ratio":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -25,6 +25,7 @@ type createScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
@@ -32,6 +33,7 @@ type updateScheduledTestPlanRequest struct {
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
AutoRecover *bool `json:"auto_recover"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
@@ -68,6 +70,9 @@ func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
plan.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
@@ -109,6 +114,9 @@ func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
if req.AutoRecover != nil {
|
||||
existing.AutoRecover = *req.AutoRecover
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
|
||||
@@ -1348,6 +1348,118 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// GetRectifierSettings 获取请求整流器配置
|
||||
// GET /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: settings.Enabled,
|
||||
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRectifierSettingsRequest 更新整流器配置请求
|
||||
type UpdateRectifierSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// UpdateRectifierSettings 更新请求整流器配置
|
||||
// PUT /api/v1/admin/settings/rectifier
|
||||
func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
var req UpdateRectifierSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.RectifierSettings{
|
||||
Enabled: req.Enabled,
|
||||
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
|
||||
}
|
||||
|
||||
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RectifierSettings{
|
||||
Enabled: updatedSettings.Enabled,
|
||||
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
|
||||
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||
// GET /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
|
||||
for i, r := range settings.Rules {
|
||||
rules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: rules})
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
|
||||
type UpdateBetaPolicySettingsRequest struct {
|
||||
Rules []dto.BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettings 更新 Beta 策略配置
|
||||
// PUT /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
|
||||
var req UpdateBetaPolicySettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]service.BetaPolicyRule, len(req.Rules))
|
||||
for i, r := range req.Rules {
|
||||
rules[i] = service.BetaPolicyRule(r)
|
||||
}
|
||||
|
||||
settings := &service.BetaPolicySettings{Rules: rules}
|
||||
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch to return updated settings
|
||||
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
|
||||
for i, r := range updated.Rules {
|
||||
outRules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type snapshotCacheEntry struct {
|
||||
@@ -19,6 +21,12 @@ type snapshotCache struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
items map[string]snapshotCacheEntry
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
type snapshotCacheLoadResult struct {
|
||||
Entry snapshotCacheEntry
|
||||
Hit bool
|
||||
}
|
||||
|
||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||
if load == nil {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return entry, true, nil
|
||||
}
|
||||
if c == nil || key == "" {
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
return c.Set(key, payload), false, nil
|
||||
}
|
||||
|
||||
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||
}
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
result, ok := value.(snapshotCacheLoadResult)
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
return result.Entry, result.Hit, nil
|
||||
}
|
||||
|
||||
func buildETagFromAny(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||
require.Empty(t, etag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
|
||||
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"hello": "world"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, hit)
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
|
||||
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"unexpected": "value"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, hit)
|
||||
require.Equal(t, entry.ETag, entry2.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
start := make(chan struct{})
|
||||
const callers = 8
|
||||
errCh := make(chan error, callers)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
for range callers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||
loads.Add(1)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return "value", nil
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||
type ResetSubscriptionQuotaRequest struct {
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
}
|
||||
|
||||
// ResetQuota resets daily and/or weekly usage for a subscription.
|
||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
var req ResetSubscriptionQuotaRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Daily && !req.Weekly {
|
||||
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
||||
return
|
||||
}
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
|
||||
@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
|
||||
if tokenErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
|
||||
return
|
||||
}
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", "invitation_required")
|
||||
fragment.Set("pending_oauth_token", pendingToken)
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
return
|
||||
}
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
}
|
||||
|
||||
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||
// the invitation code and creating the user account.
|
||||
// POST /api/v1/auth/oauth/linuxdo/complete-registration
|
||||
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
var req completeLinuxDoOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
|
||||
@@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
out := &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -98,6 +98,19 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) {
|
||||
t := k.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
out.Reset5hAt = &t
|
||||
}
|
||||
if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) {
|
||||
t := k.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
out.Reset1dAt = &t
|
||||
}
|
||||
if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) {
|
||||
t := k.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
out.Reset7dAt = &t
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
@@ -122,14 +135,15 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
@@ -255,11 +269,19 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
if out.QuotaLimit != nil {
|
||||
used := a.GetQuotaUsed()
|
||||
out.QuotaUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -475,6 +497,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
|
||||
@@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -161,6 +161,26 @@ type StreamTimeoutSettings struct {
|
||||
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||
}
|
||||
|
||||
// RectifierSettings 请求整流器配置 DTO
|
||||
type RectifierSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// BetaPolicyRule Beta 策略规则 DTO
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// BetaPolicySettings Beta 策略配置 DTO
|
||||
type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -57,6 +57,9 @@ type APIKey struct {
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
Reset5hAt *time.Time `json:"reset_5h_at,omitempty"`
|
||||
Reset1dAt *time.Time `json:"reset_1d_at,omitempty"`
|
||||
Reset7dAt *time.Time `json:"reset_7d_at,omitempty"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
@@ -114,6 +117,8 @@ type AdminGroup struct {
|
||||
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
// Claude usage 模拟开关(仅管理员可见)
|
||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
@@ -193,8 +198,12 @@ type Account struct {
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
// API Key 账号配额限制
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaLimit *float64 `json:"quota_limit,omitempty"`
|
||||
QuotaUsed *float64 `json:"quota_used,omitempty"`
|
||||
QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"`
|
||||
QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"`
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
@@ -315,6 +324,8 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
maxSameAccountRetries = 3
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
|
||||
@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
for i := 1; i <= maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, i, fs.SameAccountRetryCount[100])
|
||||
}
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
}
|
||||
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
// 再次遇到账号 100,计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
}
|
||||
// 再次触发时才会执行 TempUnschedule + 切换
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
for i := 0; i < maxSameAccountRetries; i++ {
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
}
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
@@ -439,6 +439,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
@@ -630,6 +631,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// ===== 用户消息串行队列 END =====
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
c.Set("parsed_request", parsedReq)
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
@@ -652,6 +654,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
// Beta policy block: return 400 immediately, no failover
|
||||
var betaBlockedErr *service.BetaBlockedError
|
||||
if errors.As(err, &betaBlockedErr) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||
return
|
||||
}
|
||||
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||
@@ -734,6 +743,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
@@ -972,33 +982,45 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.EffectiveUsage5h()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit5h-used),
|
||||
"window_start": rateLimitData.Window5hStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) {
|
||||
entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.EffectiveUsage1d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit1d-used),
|
||||
"window_start": rateLimitData.Window1dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) {
|
||||
entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.EffectiveUsage7d()
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
entry := gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit7d-used),
|
||||
"window_start": rateLimitData.Window7dStart,
|
||||
})
|
||||
}
|
||||
if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) {
|
||||
entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d)
|
||||
}
|
||||
rateLimits = append(rateLimits, entry)
|
||||
}
|
||||
if len(rateLimits) > 0 {
|
||||
resp["rate_limits"] = rateLimits
|
||||
|
||||
@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
@@ -155,6 +156,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // sessionLimitCache
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
|
||||
@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
|
||||
@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
290
backend/internal/handler/openai_chat_completions.go
Normal file
290
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||
// POST /v1/chat/completions
|
||||
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_chat_completions.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -212,6 +213,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -259,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
@@ -288,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -538,9 +560,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
|
||||
// 而非 OpenAI 的 session_id/conversation_id headers。
|
||||
// 从中派生 sessionHash(sticky session)和 promptCacheKey(upstream cache)。
|
||||
if sessionHash == "" || promptCacheKey == "" {
|
||||
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
|
||||
seed := reqModel + "-" + userID
|
||||
if promptCacheKey == "" {
|
||||
promptCacheKey = service.GenerateSessionUUID(seed)
|
||||
}
|
||||
if sessionHash == "" {
|
||||
sessionHash = service.DeriveSessionHashFromSeed(seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
@@ -602,6 +640,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -641,6 +680,25 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
@@ -1456,6 +1514,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
|
||||
if sessionHash != "" || account == nil || !account.IsPoolMode() {
|
||||
return sessionHash
|
||||
}
|
||||
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
|
||||
return "openai-pool-retry-" + uuid.NewString()
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
const (
|
||||
opsErrorLogTimeout = 5 * time.Second
|
||||
opsErrorLogDrainTimeout = 10 * time.Second
|
||||
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||
|
||||
opsErrorLogMinWorkerCount = 4
|
||||
opsErrorLogMaxWorkerCount = 32
|
||||
@@ -38,6 +39,7 @@ const (
|
||||
opsErrorLogQueueSizePerWorker = 128
|
||||
opsErrorLogMinQueueSize = 256
|
||||
opsErrorLogMaxQueueSize = 8192
|
||||
opsErrorLogBatchSize = 32
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
@@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() {
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go func() {
|
||||
defer opsErrorLogWorkersWg.Done()
|
||||
for job := range opsErrorLogQueue {
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
for {
|
||||
job, ok := <-opsErrorLogQueue
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||
batch = append(batch, job)
|
||||
|
||||
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||
batchLoop:
|
||||
for len(batch) < opsErrorLogBatchSize {
|
||||
select {
|
||||
case nextJob, ok := <-opsErrorLogQueue:
|
||||
if !ok {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
return
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch = append(batch, nextJob)
|
||||
case <-timer.C:
|
||||
break batchLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||
var processed int64
|
||||
for _, job := range batch {
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
}
|
||||
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||
processed++
|
||||
}
|
||||
if processed == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for opsSvc, entries := range grouped {
|
||||
if opsSvc == nil || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||
cancel()
|
||||
}
|
||||
opsErrorLogProcessed.Add(processed)
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
|
||||
@@ -2207,7 +2207,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -445,6 +445,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
nil, // settingService
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
|
||||
@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
|
||||
@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
|
||||
@@ -49,8 +49,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
|
||||
var defaultUserAgentVersion = "1.20.4"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
|
||||
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
|
||||
@@ -18,6 +18,9 @@ const (
|
||||
BlockTypeFunction
|
||||
)
|
||||
|
||||
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
||||
type UsageMapHook func(usageMap map[string]any)
|
||||
|
||||
// StreamingProcessor 流式响应处理器
|
||||
type StreamingProcessor struct {
|
||||
blockType BlockType
|
||||
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
|
||||
originalModel string
|
||||
webSearchQueries []string
|
||||
groundingChunks []GeminiGroundingChunk
|
||||
usageMapHook UsageMapHook
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
@@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
|
||||
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
|
||||
p.usageMapHook = fn
|
||||
}
|
||||
|
||||
func usageToMap(u ClaudeUsage) map[string]any {
|
||||
m := map[string]any{
|
||||
"input_tokens": u.InputTokens,
|
||||
"output_tokens": u.OutputTokens,
|
||||
}
|
||||
if u.CacheCreationInputTokens > 0 {
|
||||
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
||||
}
|
||||
if u.CacheReadInputTokens > 0 {
|
||||
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
line = strings.TrimSpace(line)
|
||||
@@ -168,6 +191,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
responseID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
var usageValue any = usage
|
||||
if p.usageMapHook != nil {
|
||||
usageMap := usageToMap(usage)
|
||||
p.usageMapHook(usageMap)
|
||||
usageValue = usageMap
|
||||
}
|
||||
|
||||
message := map[string]any{
|
||||
"id": responseID,
|
||||
"type": "message",
|
||||
@@ -176,7 +206,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
"model": p.originalModel,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
"usage": usageValue,
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
@@ -487,13 +517,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
var usageValue any = usage
|
||||
if p.usageMapHook != nil {
|
||||
usageMap := usageToMap(usage)
|
||||
p.usageMapHook(usageMap)
|
||||
usageValue = usageMap
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": usage,
|
||||
"usage": usageValue,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||
|
||||
@@ -631,7 +631,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.Contains(t, resp.Include, "reasoning.encrypted_content")
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
@@ -648,7 +649,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
@@ -663,8 +665,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
// Default effort applies (high → xhigh) even when thinking is disabled.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
@@ -676,7 +679,93 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Reasoning)
|
||||
// Default effort applies (high → xhigh) when no thinking/output_config is set.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// output_config.effort override tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
|
||||
// Default is xhigh, but output_config.effort="low" overrides. low→low after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "low"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "low", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
|
||||
// No thinking field, but output_config.effort="medium" → creates reasoning.
|
||||
// medium→high after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "medium"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
|
||||
// output_config.effort="high" → mapped to "xhigh".
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "high"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
|
||||
// No output_config → default xhigh regardless of thinking.type.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: 10000},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
|
||||
// output_config present but effort empty (e.g. only format set) → default xhigh.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -733,3 +822,188 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "get_weather", fn["name"])
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Image content block conversion tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_UserImageBlock(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"text","text":"What is in this image?"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "What is in this image?", parts[0].Text)
|
||||
assert.Equal(t, "input_image", parts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ImageOnlyUserMessage(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"/9j/4AAQ"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/jpeg;base64,/9j/4AAQ", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolResultWithImage(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Read the screenshot"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_1","name":"Read","input":{"file_path":"/tmp/screen.png"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"toolu_1","content":[
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBOR"}}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output + user(image) = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
// function_call_output should have text-only output (no image).
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "fc_toolu_1", items[2].CallID)
|
||||
assert.Equal(t, "(empty)", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
assert.Equal(t, "user", items[3].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[3].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ToolResultMixed(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Describe the file"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"toolu_2","name":"Read","input":{"file_path":"/tmp/photo.png"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"toolu_2","content":[
|
||||
{"type":"text","text":"File metadata: 800x600 PNG"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output + user(image) = 4
|
||||
require.Len(t, items, 4)
|
||||
|
||||
// function_call_output should have text-only output.
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "File metadata: 800x600 PNG", items[2].Output)
|
||||
|
||||
// Image should be in a separate user message.
|
||||
assert.Equal(t, "user", items[3].Role)
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[3].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
assert.Equal(t, "data:image/png;base64,AAAA", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_TextOnlyToolResultBackwardCompat(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Check weather"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"tool_use","id":"call_1","name":"get_weather","input":{"city":"NYC"}}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"call_1","content":[
|
||||
{"type":"text","text":"Sunny, 72°F"}
|
||||
]}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output = 3
|
||||
require.Len(t, items, 3)
|
||||
|
||||
// Text-only tool_result should produce a plain string.
|
||||
assert.Equal(t, "Sunny, 72°F", items[2].Output)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: json.RawMessage(`[
|
||||
{"type":"image","source":{"type":"base64","media_type":"","data":"iVBOR"}}
|
||||
]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "input_image", parts[0].Type)
|
||||
// Should default to image/png when media_type is empty.
|
||||
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
|
||||
}
|
||||
|
||||
@@ -45,18 +45,16 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
out.Tools = convertAnthropicToolsToResponses(req.Tools)
|
||||
}
|
||||
|
||||
// Convert thinking → reasoning.
|
||||
// generate_summary="auto" causes the upstream to emit reasoning_summary_text
|
||||
// streaming events; the include array only needs reasoning.encrypted_content
|
||||
// (already set above) for content continuity.
|
||||
if req.Thinking != nil {
|
||||
switch req.Thinking.Type {
|
||||
case "enabled":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "high", Summary: "auto"}
|
||||
case "adaptive":
|
||||
out.Reasoning = &ResponsesReasoning{Effort: "medium", Summary: "auto"}
|
||||
}
|
||||
// "disabled" or unknown → omit reasoning
|
||||
// Determine reasoning effort: only output_config.effort controls the
|
||||
// level; thinking.type is ignored. Default is xhigh when unset.
|
||||
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
|
||||
effort := "high" // default → maps to xhigh
|
||||
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
|
||||
effort = req.OutputConfig.Effort
|
||||
}
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: mapAnthropicEffortToResponses(effort),
|
||||
Summary: "auto",
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
@@ -169,7 +167,7 @@ func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, err
|
||||
|
||||
// anthropicUserToResponses handles an Anthropic user message. Content can be a
|
||||
// plain string or an array of blocks. tool_result blocks are extracted into
|
||||
// function_call_output items.
|
||||
// function_call_output items. Image blocks are converted to input_image parts.
|
||||
func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string.
|
||||
var s string
|
||||
@@ -184,28 +182,46 @@ func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error)
|
||||
}
|
||||
|
||||
var out []ResponsesInputItem
|
||||
var toolResultImageParts []ResponsesContentPart
|
||||
|
||||
// Extract tool_result blocks → function_call_output items.
|
||||
// Images inside tool_results are extracted separately because the
|
||||
// Responses API function_call_output.output only accepts strings.
|
||||
for _, b := range blocks {
|
||||
if b.Type != "tool_result" {
|
||||
continue
|
||||
}
|
||||
text := extractAnthropicToolResultText(b)
|
||||
if text == "" {
|
||||
// OpenAI Responses API requires "output" field; use placeholder for empty results.
|
||||
text = "(empty)"
|
||||
}
|
||||
outputText, imageParts := convertToolResultOutput(b)
|
||||
out = append(out, ResponsesInputItem{
|
||||
Type: "function_call_output",
|
||||
CallID: toResponsesCallID(b.ToolUseID),
|
||||
Output: text,
|
||||
Output: outputText,
|
||||
})
|
||||
toolResultImageParts = append(toolResultImageParts, imageParts...)
|
||||
}
|
||||
|
||||
// Remaining text blocks → user message.
|
||||
text := extractAnthropicTextFromBlocks(blocks)
|
||||
if text != "" {
|
||||
content, _ := json.Marshal(text)
|
||||
// Remaining text + image blocks → user message with content parts.
|
||||
// Also include images extracted from tool_results so the model can see them.
|
||||
var parts []ResponsesContentPart
|
||||
for _, b := range blocks {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
if b.Text != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(b.Source); uri != "" {
|
||||
parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, toolResultImageParts...)
|
||||
|
||||
if len(parts) > 0 {
|
||||
content, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ResponsesInputItem{Role: "user", Content: content})
|
||||
}
|
||||
|
||||
@@ -290,26 +306,64 @@ func fromResponsesCallID(id string) string {
|
||||
return id
|
||||
}
|
||||
|
||||
// extractAnthropicToolResultText gets the text content from a tool_result block.
|
||||
func extractAnthropicToolResultText(b AnthropicContentBlock) string {
|
||||
if len(b.Content) == 0 {
|
||||
// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string.
|
||||
// Returns "" if the source is nil or has no data.
|
||||
func anthropicImageToDataURI(src *AnthropicImageSource) string {
|
||||
if src == nil || src.Data == "" {
|
||||
return ""
|
||||
}
|
||||
mediaType := src.MediaType
|
||||
if mediaType == "" {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
return "data:" + mediaType + ";base64," + src.Data
|
||||
}
|
||||
|
||||
// convertToolResultOutput extracts text and image content from a tool_result
|
||||
// block. Returns the text as a string for the function_call_output Output
|
||||
// field, plus any image parts that must be sent in a separate user message
|
||||
// (the Responses API output field only accepts strings).
|
||||
func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) {
|
||||
if len(b.Content) == 0 {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Try plain string content.
|
||||
var s string
|
||||
if err := json.Unmarshal(b.Content, &s); err == nil {
|
||||
return s
|
||||
if s == "" {
|
||||
s = "(empty)"
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Array of content blocks — may contain text and/or images.
|
||||
var inner []AnthropicContentBlock
|
||||
if err := json.Unmarshal(b.Content, &inner); err == nil {
|
||||
var parts []string
|
||||
for _, ib := range inner {
|
||||
if ib.Type == "text" && ib.Text != "" {
|
||||
parts = append(parts, ib.Text)
|
||||
if err := json.Unmarshal(b.Content, &inner); err != nil {
|
||||
return "(empty)", nil
|
||||
}
|
||||
|
||||
// Separate text (for function_call_output) from images (for user message).
|
||||
var textParts []string
|
||||
var imageParts []ResponsesContentPart
|
||||
for _, ib := range inner {
|
||||
switch ib.Type {
|
||||
case "text":
|
||||
if ib.Text != "" {
|
||||
textParts = append(textParts, ib.Text)
|
||||
}
|
||||
case "image":
|
||||
if uri := anthropicImageToDataURI(ib.Source); uri != "" {
|
||||
imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
return ""
|
||||
|
||||
text := strings.Join(textParts, "\n\n")
|
||||
if text == "" {
|
||||
text = "(empty)"
|
||||
}
|
||||
return text, imageParts
|
||||
}
|
||||
|
||||
// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/
|
||||
@@ -324,6 +378,23 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
|
||||
// OpenAI Responses API effort levels.
|
||||
//
|
||||
// low → low
|
||||
// medium → high
|
||||
// high → xhigh
|
||||
func mapAnthropicEffortToResponses(effort string) string {
|
||||
switch effort {
|
||||
case "medium":
|
||||
return "high"
|
||||
case "high":
|
||||
return "xhigh"
|
||||
default:
|
||||
return effort // "low" and any unknown values pass through unchanged
|
||||
}
|
||||
}
|
||||
|
||||
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
|
||||
// Responses API tools. Server-side tools like web_search are mapped to their
|
||||
// OpenAI equivalents; regular tools become function tools.
|
||||
|
||||
733
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
733
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
@@ -0,0 +1,733 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ChatCompletionsToResponses tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4o", resp.Model)
|
||||
assert.True(t, resp.Stream) // always forced true
|
||||
assert.False(t, *resp.Store)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "system", items[0].Role)
|
||||
assert.Equal(t, "user", items[1].Role)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "ping",
|
||||
Arguments: `{"host":"example.com"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_1",
|
||||
Content: json.RawMessage(`"pong"`),
|
||||
},
|
||||
},
|
||||
Tools: []ChatTool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: &ChatFunction{
|
||||
Name: "ping",
|
||||
Description: "Ping a host",
|
||||
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output = 3
|
||||
// (assistant message with empty content + tool_calls → only function_call items emitted)
|
||||
require.Len(t, items, 3)
|
||||
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "call_1", items[2].CallID)
|
||||
assert.Equal(t, "pong", items[2].Output)
|
||||
|
||||
// Check tools
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "ping", resp.Tools[0].Name)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
|
||||
t.Run("max_tokens", func(t *testing.T) {
|
||||
maxTokens := 100
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
MaxTokens: &maxTokens,
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.MaxOutputTokens)
|
||||
// Below minMaxOutputTokens (128), should be clamped
|
||||
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
|
||||
})
|
||||
|
||||
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
|
||||
maxTokens := 100
|
||||
maxCompletion := 500
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
MaxTokens: &maxTokens,
|
||||
MaxCompletionTokens: &maxCompletion,
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.MaxOutputTokens)
|
||||
assert.Equal(t, 500, *resp.MaxOutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
ReasoningEffort: "high",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
|
||||
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(content)},
|
||||
},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "Describe this", parts[0].Text)
|
||||
assert.Equal(t, "input_image", parts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
},
|
||||
Functions: []ChatFunction{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
},
|
||||
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||
|
||||
// tool_choice should be converted
|
||||
require.NotNil(t, resp.ToolChoice)
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
ServiceTier: "flex",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "flex", resp.ServiceTier)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Do something"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`"Let me call a function."`),
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_abc",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "do_thing",
|
||||
Arguments: `{}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + assistant message (with text) + function_call
|
||||
require.Len(t, items, 3)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ResponsesToChatCompletions tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_123",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Hello, world!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
assert.Equal(t, "chat.completion", chat.Object)
|
||||
assert.Equal(t, "gpt-4o", chat.Model)
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "Hello, world!", content)
|
||||
|
||||
require.NotNil(t, chat.Usage)
|
||||
assert.Equal(t, 10, chat.Usage.PromptTokens)
|
||||
assert.Equal(t, 5, chat.Usage.CompletionTokens)
|
||||
assert.Equal(t, 15, chat.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_456",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_xyz",
|
||||
Name: "get_weather",
|
||||
Arguments: `{"city":"NYC"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)
|
||||
|
||||
msg := chat.Choices[0].Message
|
||||
require.Len(t, msg.ToolCalls, 1)
|
||||
assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
|
||||
assert.Equal(t, "function", msg.ToolCalls[0].Type)
|
||||
assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
|
||||
assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_789",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "reasoning",
|
||||
Summary: []ResponsesSummary{
|
||||
{Type: "summary_text", Text: "I thought about it."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "The answer is 42."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
// Reasoning summary is prepended to text
|
||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_inc",
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "partial..."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "length", chat.Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cache",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 10,
|
||||
TotalTokens: 110,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 80,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.NotNil(t, chat.Usage)
|
||||
require.NotNil(t, chat.Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_ws",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "web_search_call",
|
||||
Action: &WebSearchAction{Type: "search", Query: "test"},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "search results", content)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesEventToChatChunks tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
// response.created → role chunk
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{
|
||||
ID: "resp_stream",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
|
||||
assert.True(t, state.SentRole)
|
||||
|
||||
// response.output_text.delta → content chunk
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "Hello",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 1,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
CallID: "call_1",
|
||||
Name: "get_weather",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
|
||||
tc := chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
assert.Equal(t, "call_1", tc.ID)
|
||||
assert.Equal(t, "get_weather", tc.Function.Name)
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index)
|
||||
|
||||
// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 1, // matches the output_index from output_item.added above
|
||||
Delta: `{"city":`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
|
||||
assert.Equal(t, `{"city":`, tc.Function.Arguments)
|
||||
|
||||
// Add a second function call at output_index=2
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 2,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
CallID: "call_2",
|
||||
Name: "get_time",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")
|
||||
|
||||
// Argument delta for second tool call
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 2,
|
||||
Delta: `{"tz":"UTC"}`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")
|
||||
|
||||
// Argument delta for first tool call (interleaved)
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 1,
|
||||
Delta: `"Tokyo"}`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 50,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 70,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
// finish chunk + usage chunk
|
||||
require.Len(t, chunks, 2)
|
||||
|
||||
// First chunk: finish_reason
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
|
||||
// Second chunk: usage
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
|
||||
assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
|
||||
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SawToolCall = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
state.Usage = &ChatUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
}
|
||||
|
||||
chunks := FinalizeResponsesChatStream(state)
|
||||
require.Len(t, chunks, 2)
|
||||
|
||||
// Finish chunk
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
|
||||
// Usage chunk
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 100, chunks[1].Usage.PromptTokens)
|
||||
|
||||
// Idempotent: second call returns nil
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
|
||||
// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
|
||||
// must be a no-op (prevents double finish_reason being sent to the client).
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
// Simulate response.completed
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
require.NotEmpty(t, chunks) // finish + usage chunks
|
||||
|
||||
// Now FinalizeResponsesChatStream should return nil — already finalized.
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestChatChunkToSSE(t *testing.T) {
|
||||
chunk := ChatCompletionsChunk{
|
||||
ID: "chatcmpl-test",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1700000000,
|
||||
Model: "gpt-4o",
|
||||
Choices: []ChatChunkChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: ChatDelta{Role: "assistant"},
|
||||
FinishReason: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sse, err := ChatChunkToSSE(chunk)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, sse, "data: ")
|
||||
assert.Contains(t, sse, "chatcmpl-test")
|
||||
assert.Contains(t, sse, "assistant")
|
||||
assert.True(t, len(sse) > 10)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stream round-trip test
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
|
||||
// Simulate: client sends chat completions request, upstream returns Responses SSE events.
|
||||
// Verify that the streaming state machine produces correct chat completions chunks.
|
||||
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
var allChunks []ChatCompletionsChunk
|
||||
|
||||
// 1. response.created
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_rt"},
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
|
||||
// 2. text deltas
|
||||
for _, text := range []string{"Hello", ", ", "world", "!"} {
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: text,
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
}
|
||||
|
||||
// 3. response.completed
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 4,
|
||||
TotalTokens: 14,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
|
||||
// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
|
||||
require.Len(t, allChunks, 7)
|
||||
|
||||
// First chunk has role
|
||||
assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)
|
||||
|
||||
// Text chunks
|
||||
var fullText string
|
||||
for i := 1; i <= 4; i++ {
|
||||
require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
|
||||
fullText += *allChunks[i].Choices[0].Delta.Content
|
||||
}
|
||||
assert.Equal(t, "Hello, world!", fullText)
|
||||
|
||||
// Finish chunk
|
||||
require.NotNil(t, allChunks[5].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)
|
||||
|
||||
// Usage chunk
|
||||
require.NotNil(t, allChunks[6].Usage)
|
||||
assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
|
||||
assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)
|
||||
|
||||
// All chunks share the same ID
|
||||
for _, c := range allChunks {
|
||||
assert.Equal(t, "resp_rt", c.ID)
|
||||
}
|
||||
}
|
||||
312
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
312
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
// Responses API request. The upstream always streams, so Stream is forced to
|
||||
// true. store is always false and reasoning.encrypted_content is always
|
||||
// included so that the response translator has full context.
|
||||
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
|
||||
input, err := convertChatMessagesToResponsesInput(req.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &ResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: true, // upstream always streams
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
ServiceTier: req.ServiceTier,
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
|
||||
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
|
||||
maxTokens := 0
|
||||
if req.MaxTokens != nil {
|
||||
maxTokens = *req.MaxTokens
|
||||
}
|
||||
if req.MaxCompletionTokens != nil {
|
||||
maxTokens = *req.MaxCompletionTokens
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
v := maxTokens
|
||||
if v < minMaxOutputTokens {
|
||||
v = minMaxOutputTokens
|
||||
}
|
||||
out.MaxOutputTokens = &v
|
||||
}
|
||||
|
||||
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
|
||||
if req.ReasoningEffort != "" {
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: req.ReasoningEffort,
|
||||
Summary: "auto",
|
||||
}
|
||||
}
|
||||
|
||||
// tools[] and legacy functions[] → ResponsesTool[]
|
||||
if len(req.Tools) > 0 || len(req.Functions) > 0 {
|
||||
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
|
||||
}
|
||||
|
||||
// tool_choice: already compatible format — pass through directly.
|
||||
// Legacy function_call needs mapping.
|
||||
if len(req.ToolChoice) > 0 {
|
||||
out.ToolChoice = req.ToolChoice
|
||||
} else if len(req.FunctionCall) > 0 {
|
||||
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert function_call: %w", err)
|
||||
}
|
||||
out.ToolChoice = tc
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// convertChatMessagesToResponsesInput converts the Chat Completions messages
|
||||
// array into a Responses API input items array.
|
||||
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var out []ResponsesInputItem
|
||||
for _, m := range msgs {
|
||||
items, err := chatMessageToResponsesItems(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// chatMessageToResponsesItems converts a single ChatMessage into one or more
|
||||
// ResponsesInputItem values.
|
||||
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
switch m.Role {
|
||||
case "system":
|
||||
return chatSystemToResponses(m)
|
||||
case "user":
|
||||
return chatUserToResponses(m)
|
||||
case "assistant":
|
||||
return chatAssistantToResponses(m)
|
||||
case "tool":
|
||||
return chatToolToResponses(m)
|
||||
case "function":
|
||||
return chatFunctionToResponses(m)
|
||||
default:
|
||||
return chatUserToResponses(m)
|
||||
}
|
||||
}
|
||||
|
||||
// chatSystemToResponses converts a system message.
|
||||
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
text, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content, err := json.Marshal(text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
|
||||
}
|
||||
|
||||
// chatUserToResponses converts a user message, handling both plain strings and
|
||||
// multi-modal content arrays.
|
||||
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string first.
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var parts []ChatContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err != nil {
|
||||
return nil, fmt.Errorf("parse user content: %w", err)
|
||||
}
|
||||
|
||||
var responseParts []ResponsesContentPart
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_text",
|
||||
Text: p.Text,
|
||||
})
|
||||
}
|
||||
case "image_url":
|
||||
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_image",
|
||||
ImageURL: p.ImageURL.URL,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content, err := json.Marshal(responseParts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
// chatAssistantToResponses converts an assistant message. If there is both
|
||||
// text content and tool_calls, the text is emitted as an assistant message
|
||||
// first, then each tool_call becomes a function_call item. If the content is
|
||||
// empty/nil and there are tool_calls, only function_call items are emitted.
|
||||
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var items []ResponsesInputItem
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
}
|
||||
|
||||
// Emit one function_call item per tool_call.
|
||||
for _, tc := range m.ToolCalls {
|
||||
args := tc.Function.Arguments
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
items = append(items, ResponsesInputItem{
|
||||
Type: "function_call",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
ID: tc.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
output, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if output == "" {
|
||||
output = "(empty)"
|
||||
}
|
||||
return []ResponsesInputItem{{
|
||||
Type: "function_call_output",
|
||||
CallID: m.ToolCallID,
|
||||
Output: output,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// chatFunctionToResponses converts a legacy function result message
|
||||
// (role=function) into a function_call_output item. The Name field is used as
|
||||
// call_id since legacy function calls do not carry a separate call_id.
|
||||
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
output, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if output == "" {
|
||||
output = "(empty)"
|
||||
}
|
||||
return []ResponsesInputItem{{
|
||||
Type: "function_call_output",
|
||||
CallID: m.Name,
|
||||
Output: output,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// parseChatContent returns the string value of a ChatMessage Content field.
|
||||
// Content must be a JSON string. Returns "" if content is null or empty.
|
||||
func parseChatContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err != nil {
|
||||
return "", fmt.Errorf("parse content as string: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
|
||||
// function definitions to Responses API tool definitions.
|
||||
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
|
||||
var out []ResponsesTool
|
||||
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" || t.Function == nil {
|
||||
continue
|
||||
}
|
||||
rt := ResponsesTool{
|
||||
Type: "function",
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: t.Function.Parameters,
|
||||
Strict: t.Function.Strict,
|
||||
}
|
||||
out = append(out, rt)
|
||||
}
|
||||
|
||||
// Legacy functions[] are treated as function-type tools.
|
||||
for _, f := range functions {
|
||||
rt := ResponsesTool{
|
||||
Type: "function",
|
||||
Name: f.Name,
|
||||
Description: f.Description,
|
||||
Parameters: f.Parameters,
|
||||
Strict: f.Strict,
|
||||
}
|
||||
out = append(out, rt)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
|
||||
// Responses API tool_choice value.
|
||||
//
|
||||
// "auto" → "auto"
|
||||
// "none" → "none"
|
||||
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Object form: {"name":"X"}
|
||||
var obj struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": obj.Name},
|
||||
})
|
||||
}
|
||||
368
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
368
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesToChatCompletions converts a Responses API response into a Chat
|
||||
// Completions response. Text output items are concatenated into
|
||||
// choices[0].message.content; function_call items become tool_calls.
|
||||
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
|
||||
id := resp.ID
|
||||
if id == "" {
|
||||
id = generateChatCmplID()
|
||||
}
|
||||
|
||||
out := &ChatCompletionsResponse{
|
||||
ID: id,
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: model,
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
for _, part := range item.Content {
|
||||
if part.Type == "output_text" && part.Text != "" {
|
||||
contentText += part.Text
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
toolCalls = append(toolCalls, ChatToolCall{
|
||||
ID: item.CallID,
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: item.Name,
|
||||
Arguments: item.Arguments,
|
||||
},
|
||||
})
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
contentText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
// silently consumed — results already incorporated into text output
|
||||
}
|
||||
}
|
||||
|
||||
msg := ChatMessage{Role: "assistant"}
|
||||
if len(toolCalls) > 0 {
|
||||
msg.ToolCalls = toolCalls
|
||||
}
|
||||
if contentText != "" {
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
out.Choices = []ChatChoice{{
|
||||
Index: 0,
|
||||
Message: msg,
|
||||
FinishReason: finishReason,
|
||||
}}
|
||||
|
||||
if resp.Usage != nil {
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: resp.Usage.InputTokens,
|
||||
CompletionTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
out.Usage = usage
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
if details != nil && details.Reason == "max_output_tokens" {
|
||||
return "length"
|
||||
}
|
||||
return "stop"
|
||||
case "completed":
|
||||
if len(toolCalls) > 0 {
|
||||
return "tool_calls"
|
||||
}
|
||||
return "stop"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesEventToChatState tracks state for converting a sequence of Responses
|
||||
// SSE events into Chat Completions SSE chunks.
|
||||
type ResponsesEventToChatState struct {
|
||||
ID string
|
||||
Model string
|
||||
Created int64
|
||||
SentRole bool
|
||||
SawToolCall bool
|
||||
SawText bool
|
||||
Finalized bool // true after finish chunk has been emitted
|
||||
NextToolCallIndex int // next sequential tool_call index to assign
|
||||
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
|
||||
IncludeUsage bool
|
||||
Usage *ChatUsage
|
||||
}
|
||||
|
||||
// NewResponsesEventToChatState returns an initialised stream state.
|
||||
func NewResponsesEventToChatState() *ResponsesEventToChatState {
|
||||
return &ResponsesEventToChatState{
|
||||
ID: generateChatCmplID(),
|
||||
Created: time.Now().Unix(),
|
||||
OutputIndexToToolIndex: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
|
||||
// or more Chat Completions chunks, updating state as it goes.
|
||||
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
switch evt.Type {
|
||||
case "response.created":
|
||||
return resToChatHandleCreated(evt, state)
|
||||
case "response.output_text.delta":
|
||||
return resToChatHandleTextDelta(evt, state)
|
||||
case "response.output_item.added":
|
||||
return resToChatHandleOutputItemAdded(evt, state)
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
|
||||
// stream ended without a proper completion event (e.g. upstream disconnect).
|
||||
// It is idempotent: if a completion event already emitted the finish chunk,
|
||||
// this returns nil.
|
||||
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if state.Finalized {
|
||||
return nil
|
||||
}
|
||||
state.Finalized = true
|
||||
|
||||
finishReason := "stop"
|
||||
if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
|
||||
|
||||
if state.IncludeUsage && state.Usage != nil {
|
||||
chunks = append(chunks, ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{},
|
||||
Usage: state.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
|
||||
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("data: %s\n\n", data), nil
|
||||
}
|
||||
|
||||
// --- internal handlers ---
|
||||
|
||||
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Response != nil {
|
||||
if evt.Response.ID != "" {
|
||||
state.ID = evt.Response.ID
|
||||
}
|
||||
if state.Model == "" && evt.Response.Model != "" {
|
||||
state.Model = evt.Response.Model
|
||||
}
|
||||
}
|
||||
// Emit the role chunk.
|
||||
if state.SentRole {
|
||||
return nil
|
||||
}
|
||||
state.SentRole = true
|
||||
|
||||
role := "assistant"
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
|
||||
}
|
||||
|
||||
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
state.SawText = true
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
}
|
||||
|
||||
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Item == nil || evt.Item.Type != "function_call" {
|
||||
return nil
|
||||
}
|
||||
|
||||
state.SawToolCall = true
|
||||
idx := state.NextToolCallIndex
|
||||
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
|
||||
state.NextToolCallIndex++
|
||||
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||
ToolCalls: []ChatToolCall{{
|
||||
Index: &idx,
|
||||
ID: evt.Item.CallID,
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: evt.Item.Name,
|
||||
},
|
||||
}},
|
||||
})}
|
||||
}
|
||||
|
||||
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||
ToolCalls: []ChatToolCall{{
|
||||
Index: &idx,
|
||||
Function: ChatFunctionCall{
|
||||
Arguments: evt.Delta,
|
||||
},
|
||||
}},
|
||||
})}
|
||||
}
|
||||
|
||||
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
state.Finalized = true
|
||||
finishReason := "stop"
|
||||
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
u := evt.Response.Usage
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: u.InputTokens,
|
||||
CompletionTokens: u.OutputTokens,
|
||||
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||
}
|
||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
state.Usage = usage
|
||||
}
|
||||
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||
finishReason = "length"
|
||||
}
|
||||
case "completed":
|
||||
if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
}
|
||||
} else if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
var chunks []ChatCompletionsChunk
|
||||
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
|
||||
|
||||
if state.IncludeUsage && state.Usage != nil {
|
||||
chunks = append(chunks, ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{},
|
||||
Usage: state.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: delta,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
|
||||
empty := ""
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: ChatDelta{Content: &empty},
|
||||
FinishReason: &finishReason,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
|
||||
func generateChatCmplID() string {
|
||||
b := make([]byte, 12)
|
||||
_, _ = rand.Read(b)
|
||||
return "chatcmpl-" + hex.EncodeToString(b)
|
||||
}
|
||||
@@ -12,17 +12,23 @@ import "encoding/json"
|
||||
|
||||
// AnthropicRequest is the request body for POST /v1/messages.
|
||||
type AnthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSeqs []string `json:"stop_sequences,omitempty"`
|
||||
Thinking *AnthropicThinking `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicOutputConfig controls output generation parameters.
|
||||
type AnthropicOutputConfig struct {
|
||||
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
|
||||
}
|
||||
|
||||
// AnthropicThinking configures extended thinking in the Anthropic API.
|
||||
@@ -47,6 +53,9 @@ type AnthropicContentBlock struct {
|
||||
// type=thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// type=image
|
||||
Source *AnthropicImageSource `json:"source,omitempty"`
|
||||
|
||||
// type=tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -58,9 +67,16 @@ type AnthropicContentBlock struct {
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicImageSource describes the source data for an image content block.
|
||||
type AnthropicImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// AnthropicTool describes a tool available to the model.
|
||||
type AnthropicTool struct {
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
|
||||
@@ -146,6 +162,7 @@ type ResponsesRequest struct {
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesReasoning configures reasoning effort in the Responses API.
|
||||
@@ -176,8 +193,9 @@ type ResponsesInputItem struct {
|
||||
|
||||
// ResponsesContentPart is a typed content part in a Responses message.
|
||||
type ResponsesContentPart struct {
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"` // data URI for input_image
|
||||
}
|
||||
|
||||
// ResponsesTool describes a tool in the Responses API.
|
||||
@@ -311,6 +329,148 @@ type ResponsesStreamEvent struct {
|
||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI Chat Completions API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
|
||||
type ChatCompletionsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
|
||||
Tools []ChatTool `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
|
||||
|
||||
// Legacy function calling (deprecated but still supported)
|
||||
Functions []ChatFunction `json:"functions,omitempty"`
|
||||
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatStreamOptions configures streaming behavior.
|
||||
type ChatStreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatContentPart is a typed content part in a multi-modal message.
|
||||
type ChatContentPart struct {
|
||||
Type string `json:"type"` // "text" | "image_url"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *ChatImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
// ChatImageURL contains the URL for an image content part.
|
||||
type ChatImageURL struct {
|
||||
URL string `json:"url"`
|
||||
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
|
||||
}
|
||||
|
||||
// ChatTool describes a tool available to the model.
|
||||
type ChatTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Function *ChatFunction `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
// ChatFunction describes a function tool definition.
|
||||
type ChatFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// ChatToolCall represents a tool call made by the assistant.
|
||||
// Index is only populated in streaming chunks (omitted in non-streaming responses).
|
||||
type ChatToolCall struct {
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"` // "function"
|
||||
Function ChatFunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// ChatFunctionCall contains the function name and arguments.
|
||||
type ChatFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
|
||||
type ChatCompletionsResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "chat.completion"
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChoice `json:"choices"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ChatChoice is a single completion choice.
|
||||
type ChatChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message ChatMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
|
||||
}
|
||||
|
||||
// ChatUsage holds token counts in Chat Completions format.
|
||||
type ChatUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ChatTokenDetails provides a breakdown of token usage.
|
||||
type ChatTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
||||
type ChatCompletionsChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "chat.completion.chunk"
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChunkChoice `json:"choices"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ChatChunkChoice is a single choice in a streaming chunk.
|
||||
type ChatChunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta ChatDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"` // pointer: null when not final
|
||||
}
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
|
||||
var DroppedBetas = []string{}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
@@ -18,10 +18,12 @@ func DefaultModels() []Model {
|
||||
return []Model{
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
28
backend/internal/pkg/gemini/models_test.go
Normal file
28
backend/internal/pkg/gemini/models_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package gemini
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
byName := make(map[string]Model, len(models))
|
||||
for _, model := range models {
|
||||
byName[model.Name] = model
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"models/gemini-2.5-flash-image",
|
||||
"models/gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for _, name := range required {
|
||||
model, ok := byName[name]
|
||||
if !ok {
|
||||
t.Fatalf("expected fallback model %q to exist", name)
|
||||
}
|
||||
if len(model.SupportedGenerationMethods) == 0 {
|
||||
t.Fatalf("expected fallback model %q to advertise generation methods", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,10 +13,12 @@ type Model struct {
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
|
||||
23
backend/internal/pkg/geminicli/models_test.go
Normal file
23
backend/internal/pkg/geminicli/models_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package geminicli
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
byID := make(map[string]Model, len(DefaultModels))
|
||||
for _, model := range DefaultModels {
|
||||
byID[model.ID] = model
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for _, id := range required {
|
||||
if _, ok := byID[id]; !ok {
|
||||
t.Fatalf("expected curated Gemini model %q to exist", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,6 +268,7 @@ type IDTokenClaims struct {
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
ChatGPTPlanType string `json:"chatgpt_plan_type"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
// DecodeIDToken decodes the ID Token JWT payload without validating expiration.
|
||||
// Use this for best-effort extraction (e.g., during data import) where the token may be expired.
|
||||
func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
claims, err := DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||
const clockSkewTolerance = 120 // 秒
|
||||
now := time.Now().Unix()
|
||||
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// UserInfo represents user information extracted from ID Token claims.
|
||||
@@ -375,6 +387,7 @@ type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
PlanType string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -50,6 +51,18 @@ type accountRepository struct {
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
var schedulerNeutralExtraKeyPrefixes = []string{
|
||||
"codex_primary_",
|
||||
"codex_secondary_",
|
||||
"codex_5h_",
|
||||
"codex_7d_",
|
||||
}
|
||||
|
||||
var schedulerNeutralExtraKeys = map[string]struct{}{
|
||||
"codex_usage_updated_at": {},
|
||||
"session_window_utilization": {},
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
@@ -659,13 +672,10 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -925,6 +935,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1040,6 +1051,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1186,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
} else {
|
||||
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
|
||||
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
|
||||
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
|
||||
if len(updates) == 0 {
|
||||
return false
|
||||
}
|
||||
for key := range updates {
|
||||
if isSchedulerNeutralExtraKey(key) {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isSchedulerNeutralExtraKey(key string) bool {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
if _, ok := schedulerNeutralExtraKeys[key]; ok {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
@@ -1676,13 +1724,47 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
-- 总额度:始终递增
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
-- 日额度:仅在 quota_daily_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
@@ -1704,7 +1786,7 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return err
|
||||
}
|
||||
|
||||
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
|
||||
// 任一维度配额刚超限时触发调度快照刷新
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
|
||||
@@ -1713,14 +1795,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = jsonb_set(
|
||||
COALESCE(extra, '{}'::jsonb),
|
||||
'{quota_used}',
|
||||
'0'::jsonb
|
||||
), updated_at = NOW()
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,6 +23,7 @@ type AccountRepoSuite struct {
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
if s.accounts == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.accounts[accountID], nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
if s.accounts == nil {
|
||||
s.accounts = make(map[int64]*service.Account)
|
||||
}
|
||||
if account != nil {
|
||||
s.accounts[account.ID] = account
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -558,6 +568,26 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-clear-err",
|
||||
Status: service.StatusError,
|
||||
ErrorMessage: "temporary error",
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusActive, got.Status)
|
||||
s.Require().Empty(got.ErrorMessage)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
@@ -603,6 +633,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
s.Require().Equal("val", got.Extra["key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-neutral",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Extra: map[string]any{"codex_usage_updated_at": "old"},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{
|
||||
accounts: map[int64]*service.Account{
|
||||
account.ID: {
|
||||
ID: account.ID,
|
||||
Platform: account.Platform,
|
||||
Status: service.StatusDisabled,
|
||||
Extra: map[string]any{
|
||||
"codex_usage_updated_at": "old",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
updates := map[string]any{
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
"codex_5h_used_percent": 88.5,
|
||||
"session_window_utilization": 0.42,
|
||||
}
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
|
||||
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
|
||||
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
|
||||
|
||||
var outboxCount int
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
|
||||
s.Require().Zero(outboxCount)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().NotNil(cacheRecorder.accounts[account.ID])
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-codex-exhausted",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
|
||||
"codex_7d_reset_after_seconds": 86400,
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(0, count)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-mixed",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1, count)
|
||||
}
|
||||
|
||||
// --- GetByCRSAccountID ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
|
||||
@@ -164,9 +164,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRoutingEnabled,
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSimulateClaudeMaxEnabled,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -452,6 +453,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
|
||||
// as quota_exhausted, and returns the latest quota state in one round trip.
|
||||
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
|
||||
query := `
|
||||
UPDATE api_keys
|
||||
SET
|
||||
quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL
|
||||
RETURNING quota_used, quota, key, status
|
||||
`
|
||||
|
||||
state := &service.APIKeyQuotaUsageState{}
|
||||
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
@@ -476,8 +503,8 @@ func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
@@ -491,9 +518,9 @@ func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64)
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
@@ -619,6 +646,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
|
||||
@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
|
||||
user := s.mustCreateUser("quota-state@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
|
||||
key.Quota = 3
|
||||
key.QuotaUsed = 1
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
|
||||
|
||||
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
|
||||
s.Require().NotNil(state)
|
||||
s.Require().Equal(3.5, state.QuotaUsed)
|
||||
s.Require().Equal(3.0, state.Quota)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
|
||||
s.Require().Equal(key.Key, state.Key)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(3.5, got.QuotaUsed)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
|
||||
@@ -147,17 +147,47 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
`)
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// startupCleanupScript 清理非当前进程前缀的槽位成员。
|
||||
// KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。
|
||||
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。
|
||||
startupCleanupScript = redis.NewScript(`
|
||||
local activePrefix = ARGV[1]
|
||||
local slotTTL = tonumber(ARGV[2])
|
||||
local removed = 0
|
||||
for i = 1, #KEYS do
|
||||
local key = KEYS[i]
|
||||
local members = redis.call('ZRANGE', key, 0, -1)
|
||||
for _, member in ipairs(members) do
|
||||
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
|
||||
removed = removed + redis.call('ZREM', key, member)
|
||||
end
|
||||
end
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, slotTTL)
|
||||
end
|
||||
end
|
||||
return removed
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
if activeRequestPrefix == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 清理有序集合中非当前进程前缀的成员
|
||||
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
|
||||
for _, pattern := range slotPatterns {
|
||||
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 删除所有等待队列计数器(重启后计数器失效)
|
||||
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
|
||||
for _, pattern := range waitPatterns {
|
||||
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
|
||||
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
|
||||
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
return fmt.Errorf("del %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
||||
accountID := int64(301)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
||||
now := time.Now().Unix()
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-1"},
|
||||
redis.Z{Score: float64(now), Member: "keep-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-2"},
|
||||
redis.Z{Score: float64(now), Member: "keep-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-1"},
|
||||
redis.Z{Score: now, Member: "activeproc-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-2"},
|
||||
redis.Z{Score: now, Member: "activeproc-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
|
||||
accountID := int64(903)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.EqualValues(s.T(), 0, exists)
|
||||
}
|
||||
|
||||
@@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
|
||||
@@ -61,7 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -129,7 +130,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
@@ -16,19 +16,7 @@ type opsRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||
return &opsRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return 0, fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
q := `
|
||||
const insertOpsErrorLogSQL = `
|
||||
INSERT INTO ops_error_logs (
|
||||
request_id,
|
||||
client_request_id,
|
||||
@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs (
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
||||
) RETURNING id`
|
||||
)`
|
||||
|
||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||
return &opsRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return 0, fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
var id int64
|
||||
err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
insertOpsErrorLogSQL+" RETURNING id",
|
||||
opsInsertErrorLogArgs(input)...,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if len(inputs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = stmt.Close()
|
||||
}()
|
||||
|
||||
var inserted int64
|
||||
for _, input := range inputs {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil {
|
||||
return inserted, err
|
||||
}
|
||||
inserted++
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return inserted, err
|
||||
}
|
||||
return inserted, nil
|
||||
}
|
||||
|
||||
func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
|
||||
return []any{
|
||||
opsNullString(input.RequestID),
|
||||
opsNullString(input.ClientRequestID),
|
||||
opsNullInt64(input.UserID),
|
||||
@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs (
|
||||
input.IsRetryable,
|
||||
input.RetryCount,
|
||||
input.CreatedAt,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
|
||||
|
||||
repo := NewOpsRepository(integrationDB).(*opsRepository)
|
||||
now := time.Now().UTC()
|
||||
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
|
||||
{
|
||||
RequestID: "batch-ops-1",
|
||||
ErrorPhase: "upstream",
|
||||
ErrorType: "upstream_error",
|
||||
Severity: "error",
|
||||
StatusCode: 429,
|
||||
ErrorMessage: "rate limited",
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
RequestID: "batch-ops-2",
|
||||
ErrorPhase: "internal",
|
||||
ErrorType: "api_error",
|
||||
Severity: "error",
|
||||
StatusCode: 500,
|
||||
ErrorMessage: "internal error",
|
||||
CreatedAt: now.Add(time.Millisecond),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, inserted)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(12345)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(67890)
|
||||
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
|
||||
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
@@ -20,16 +20,16 @@ func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanReposit
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
@@ -37,7 +37,7 @@ func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*s
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
@@ -50,7 +50,7 @@ func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accou
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
@@ -65,10 +65,10 @@ func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ type scannable interface {
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
const schedulerOutboxDedupWindow = time.Second
|
||||
|
||||
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
||||
return &schedulerOutboxRepository{db: db}
|
||||
}
|
||||
@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str
|
||||
}
|
||||
payloadArg = encoded
|
||||
}
|
||||
_, err := exec.ExecContext(ctx, `
|
||||
query := `
|
||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`, eventType, accountID, groupID, payloadArg)
|
||||
`
|
||||
args := []any{eventType, accountID, groupID, payloadArg}
|
||||
if schedulerOutboxEventSupportsDedup(eventType) {
|
||||
query = `
|
||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||
SELECT $1, $2, $3, $4
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM scheduler_outbox
|
||||
WHERE event_type = $1
|
||||
AND account_id IS NOT DISTINCT FROM $2
|
||||
AND group_id IS NOT DISTINCT FROM $3
|
||||
AND created_at >= NOW() - make_interval(secs => $5)
|
||||
)
|
||||
`
|
||||
args = append(args, schedulerOutboxDedupWindow.Seconds())
|
||||
}
|
||||
_, err := exec.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func schedulerOutboxEventSupportsDedup(eventType string) bool {
|
||||
switch eventType {
|
||||
case service.SchedulerOutboxEventAccountChanged,
|
||||
service.SchedulerOutboxEventGroupChanged,
|
||||
service.SchedulerOutboxEventFullRebuild:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
55
backend/internal/repository/simple_mode_admin_concurrency.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30"
|
||||
simpleModeLegacyAdminConcurrency = 5
|
||||
simpleModeTargetAdminConcurrency = 30
|
||||
)
|
||||
|
||||
func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
|
||||
upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
if upgraded {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := client.User.Update().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency),
|
||||
).
|
||||
SetConcurrency(simpleModeTargetAdminConcurrency).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("upgrade simple mode admin concurrency: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if err := client.Setting.Create().
|
||||
SetKey(simpleModeAdminConcurrencyUpgradeKey).
|
||||
SetValue(now.Format(time.RFC3339)).
|
||||
SetUpdatedAt(now).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx); err != nil {
|
||||
return fmt.Errorf("persist admin concurrency upgrade marker: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
@@ -135,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
@@ -144,7 +145,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -158,6 +159,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
mediaType := nullString(log.MediaType)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -198,6 +200,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -1870,7 +1873,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
||||
query := `
|
||||
SELECT
|
||||
COALESCE(ul.group_id, 0) as group_id,
|
||||
COALESCE(g.name, '') as group_name,
|
||||
COALESCE(g.name, '(无分组)') as group_name,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||
@@ -2505,6 +2508,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
mediaType sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
@@ -2544,6 +2548,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&mediaType,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
@@ -2614,6 +2619,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if mediaType.Valid {
|
||||
log.MediaType = &mediaType.String
|
||||
}
|
||||
if serviceTier.Valid {
|
||||
log.ServiceTier = &serviceTier.String
|
||||
}
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
|
||||
@@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(), // image_size
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.Equal(t, int64(99), log.ID)
|
||||
require.Nil(t, log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-service-tier",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
WithArgs(
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
log.RateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
|
||||
inserted, err := repo.Create(context.Background(), log)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
@@ -280,11 +345,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
@@ -316,13 +384,53 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "flex", *log.ServiceTier)
|
||||
require.Equal(t, service.RequestTypeStream, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.False(t, log.OpenAIWSMode)
|
||||
})
|
||||
|
||||
t.Run("service_tier_is_scanned", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(3),
|
||||
int64(12),
|
||||
int64(22),
|
||||
int64(32),
|
||||
sql.NullString{Valid: true, String: "req-3"},
|
||||
"gpt-5.4",
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, log.ServiceTier)
|
||||
require.Equal(t, "priority", *log.ServiceTier)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||
query := `
|
||||
SELECT ugr.user_id, u.email, ugr.rate_multiplier
|
||||
FROM user_group_rate_multipliers ugr
|
||||
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
||||
WHERE ugr.group_id = $1
|
||||
ORDER BY ugr.user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var result []service.UserGroupRateEntry
|
||||
for rows.Next() {
|
||||
var entry service.UserGroupRateEntry
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserEmail, &entry.RateMultiplier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
|
||||
@@ -210,8 +210,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"allow_messages_dispatch": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"allow_messages_dispatch": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -643,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
|
||||
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
|
||||
@@ -228,6 +228,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
}
|
||||
@@ -244,6 +245,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
@@ -263,6 +265,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError)
|
||||
accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh)
|
||||
|
||||
// Antigravity 默认模型映射
|
||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||
@@ -392,6 +396,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 流超时处理配置
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
// 请求整流器配置
|
||||
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
||||
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
||||
// Beta 策略配置
|
||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||
// Sora S3 存储配置
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
@@ -447,6 +457,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||
subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota)
|
||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||
}
|
||||
|
||||
|
||||
@@ -61,6 +61,12 @@ func RegisterAuthRoutes(
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
auth.POST("/oauth/linuxdo/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteLinuxDoOAuthRegistration,
|
||||
)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
|
||||
@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
|
||||
},
|
||||
})
|
||||
})
|
||||
// OpenAI Chat Completions API
|
||||
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
|
||||
}
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||
|
||||
@@ -647,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
defaultPoolModeRetryCount = 3
|
||||
maxPoolModeRetryCount = 10
|
||||
)
|
||||
|
||||
// GetPoolModeRetryCount 返回池模式同账号重试次数。
|
||||
// 未配置或配置非法时回退为默认值 3;小于 0 按 0 处理;过大则截断到 10。
|
||||
func (a *Account) GetPoolModeRetryCount() int {
|
||||
if a == nil || !a.IsPoolMode() || a.Credentials == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
raw, ok := a.Credentials["pool_mode_retry_count"]
|
||||
if !ok || raw == nil {
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
count := parsePoolModeRetryCount(raw)
|
||||
if count < 0 {
|
||||
return 0
|
||||
}
|
||||
if count > maxPoolModeRetryCount {
|
||||
return maxPoolModeRetryCount
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func parsePoolModeRetryCount(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultPoolModeRetryCount
|
||||
}
|
||||
|
||||
// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码
|
||||
func isPoolModeRetryableStatus(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -1134,33 +1203,97 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
|
||||
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetQuotaLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
return a.getExtraFloat64("quota_limit")
|
||||
}
|
||||
|
||||
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
|
||||
func (a *Account) GetQuotaUsed() float64 {
|
||||
return a.getExtraFloat64("quota_used")
|
||||
}
|
||||
|
||||
// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaDailyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_daily_limit")
|
||||
}
|
||||
|
||||
// GetQuotaDailyUsed 获取当日已用额度(美元)
|
||||
func (a *Account) GetQuotaDailyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_daily_used")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyLimit 获取周额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaWeeklyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_limit")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyUsed 获取本周已用额度(美元)
|
||||
func (a *Account) GetQuotaWeeklyUsed() float64 {
|
||||
return a.getExtraFloat64("quota_weekly_used")
|
||||
}
|
||||
|
||||
// getExtraFloat64 从 Extra 中读取指定 key 的 float64 值
|
||||
func (a *Account) getExtraFloat64(key string) float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["quota_used"]; ok {
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
limit := a.GetQuotaLimit()
|
||||
if limit <= 0 {
|
||||
return false
|
||||
// getExtraTime 从 Extra 中读取 RFC3339 时间戳
|
||||
func (a *Account) getExtraTime(key string) time.Time {
|
||||
if a.Extra == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return a.GetQuotaUsed() >= limit
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
|
||||
return t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
}
|
||||
|
||||
// isPeriodExpired 检查指定周期(自 periodStart 起经过 dur)是否已过期
|
||||
func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true // 从未使用过,视为过期(下次 increment 会初始化)
|
||||
}
|
||||
return time.Since(periodStart) >= dur
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true)
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
// 总额度
|
||||
if limit := a.GetQuotaLimit(); limit > 0 && a.GetQuotaUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
|
||||
117
backend/internal/service/account_pool_mode_test.go
Normal file
117
backend/internal/service/account_pool_mode_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetPoolModeRetryCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "default_when_not_pool_mode",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "default_when_missing_retry_count",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "supports_float64_from_json_credentials",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": float64(5),
|
||||
},
|
||||
},
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "supports_json_number",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": json.Number("4"),
|
||||
},
|
||||
},
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
name: "supports_string_value",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "2",
|
||||
},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "negative_value_is_clamped_to_zero",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": -1,
|
||||
},
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "oversized_value_is_clamped_to_max",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": 99,
|
||||
},
|
||||
},
|
||||
expected: maxPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "invalid_value_falls_back_to_default",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "oops",
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -68,9 +68,9 @@ type AccountRepository interface {
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周)
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
|
||||
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
|
||||
// ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0
|
||||
ResetQuotaUsed(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
|
||||
@@ -45,16 +45,23 @@ const (
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
type TestEvent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"`
|
||||
MimeType string `json:"mime_type,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
defaultGeminiTextTestPrompt = "hi"
|
||||
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
|
||||
)
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
// All account types use full Claude Code client characteristics, only auth header differs
|
||||
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get account
|
||||
@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
if account.IsGemini() {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
return s.testGeminiAccountConnection(c, account, modelID, prompt)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.routeAntigravityTest(c, account, modelID)
|
||||
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformSora {
|
||||
@@ -406,8 +413,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
||||
mergeAccountExtra(account, updates)
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if isOAuth && s.accountRepo != nil {
|
||||
if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
@@ -416,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// testGeminiAccountConnection tests a Gemini account's connection
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
@@ -443,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create test payload (Gemini format)
|
||||
payload := createGeminiTestPayload()
|
||||
payload := createGeminiTestPayload(testModelID, prompt)
|
||||
|
||||
// Build request based on account type
|
||||
var req *http.Request
|
||||
@@ -1179,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string {
|
||||
|
||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
|
||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if strings.HasPrefix(modelID, "gemini-") {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
return s.testGeminiAccountConnection(c, account, modelID, prompt)
|
||||
}
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
@@ -1330,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// createGeminiTestPayload creates a minimal test payload for Gemini API
|
||||
func createGeminiTestPayload() []byte {
|
||||
// createGeminiTestPayload creates a minimal test payload for Gemini API.
|
||||
// Image models use the image-generation path so the frontend can preview the returned image.
|
||||
func createGeminiTestPayload(modelID string, prompt string) []byte {
|
||||
if isImageGenerationModel(modelID) {
|
||||
imagePrompt := strings.TrimSpace(prompt)
|
||||
if imagePrompt == "" {
|
||||
imagePrompt = defaultGeminiImageTestPrompt
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"parts": []map[string]any{
|
||||
{"text": imagePrompt},
|
||||
},
|
||||
},
|
||||
},
|
||||
"generationConfig": map[string]any{
|
||||
"responseModalities": []string{"TEXT", "IMAGE"},
|
||||
"imageConfig": map[string]any{
|
||||
"aspectRatio": "1:1",
|
||||
},
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(payload)
|
||||
return bytes
|
||||
}
|
||||
|
||||
textPrompt := strings.TrimSpace(prompt)
|
||||
if textPrompt == "" {
|
||||
textPrompt = defaultGeminiTextTestPrompt
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"parts": []map[string]any{
|
||||
{"text": "hi"},
|
||||
{"text": textPrompt},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1397,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
if inlineData, ok := partMap["inlineData"].(map[string]any); ok {
|
||||
mimeType, _ := inlineData["mimeType"].(string)
|
||||
data, _ := inlineData["data"].(string)
|
||||
if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" {
|
||||
s.sendEvent(c, TestEvent{
|
||||
Type: "image",
|
||||
ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data),
|
||||
MimeType: mimeType,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1583,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
|
||||
ginCtx, _ := gin.CreateTestContext(w)
|
||||
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
|
||||
|
||||
finishedAt := time.Now()
|
||||
body := w.Body.String()
|
||||
|
||||
59
backend/internal/service/account_test_service_gemini_test.go
Normal file
59
backend/internal/service/account_test_service_gemini_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
|
||||
|
||||
var parsed struct {
|
||||
Contents []struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"contents"`
|
||||
GenerationConfig struct {
|
||||
ResponseModalities []string `json:"responseModalities"`
|
||||
ImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio"`
|
||||
} `json:"imageConfig"`
|
||||
} `json:"generationConfig"`
|
||||
}
|
||||
|
||||
require.NoError(t, json.Unmarshal(payload, &parsed))
|
||||
require.Len(t, parsed.Contents, 1)
|
||||
require.Len(t, parsed.Contents[0].Parts, 1)
|
||||
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
|
||||
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
|
||||
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
|
||||
}
|
||||
|
||||
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
ctx, recorder := newSoraTestContext()
|
||||
svc := &AccountTestService{}
|
||||
|
||||
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
|
||||
|
||||
err := svc.processGeminiStream(ctx, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := recorder.Body.String()
|
||||
require.Contains(t, body, "\"type\":\"content\"")
|
||||
require.Contains(t, body, "\"text\":\"ok\"")
|
||||
require.Contains(t, body, "\"type\":\"image\"")
|
||||
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
|
||||
require.Contains(t, body, "\"mime_type\":\"image/png\"")
|
||||
}
|
||||
102
backend/internal/service/account_test_service_openai_test.go
Normal file
102
backend/internal/service/account_test_service_openai_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIAccountTestRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updatedExtra map[string]any
|
||||
rateLimitedID int64
|
||||
rateLimitedAt *time.Time
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
r.rateLimitedAt = &resetAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
||||
|
||||
`))
|
||||
resp.Header.Set("x-codex-primary-used-percent", "88")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "42")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 89,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"])
|
||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 88,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.Error(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, int64(88), repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
|
||||
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
}
|
||||
@@ -359,6 +359,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return usage, nil
|
||||
}
|
||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
|
||||
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
@@ -367,9 +368,12 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
|
||||
if (usage.FiveHour == nil || usage.SevenDay == nil) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
|
||||
mergeAccountExtra(account, updates)
|
||||
if resetAt != nil {
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
if usage.UpdatedAt == nil {
|
||||
usage.UpdatedAt = &now
|
||||
}
|
||||
@@ -409,6 +413,40 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if usage == nil {
|
||||
return true
|
||||
}
|
||||
if usage.FiveHour == nil || usage.SevenDay == nil {
|
||||
return true
|
||||
}
|
||||
if account.IsRateLimited() {
|
||||
return true
|
||||
}
|
||||
return isOpenAICodexSnapshotStale(account, now)
|
||||
}
|
||||
|
||||
func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool {
|
||||
if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() {
|
||||
return false
|
||||
}
|
||||
if account.Extra == nil {
|
||||
return true
|
||||
}
|
||||
raw, ok := account.Extra["codex_usage_updated_at"]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
ts, err := parseTime(fmt.Sprint(raw))
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return now.Sub(ts) >= openAIProbeCacheTTL
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return true
|
||||
@@ -422,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
|
||||
if account == nil || !account.IsOAuth() {
|
||||
return nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
return nil, nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
modelID := openaipkg.DefaultTestModel
|
||||
payload := createOpenAITestPayload(modelID, true)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create openai probe request: %w", err)
|
||||
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
|
||||
}
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -470,29 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build openai probe client: %w", err)
|
||||
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(updates) > 0 || resetAt != nil {
|
||||
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
|
||||
return updates, resetAt, nil
|
||||
}
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
|
||||
if s == nil || s.accountRepo == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
if len(updates) == 0 && resetAt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}
|
||||
if resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
|
||||
if resp == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
baseTime := time.Now()
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
|
||||
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
return updates, resetAt, nil
|
||||
}
|
||||
return nil, resetAt, nil
|
||||
}
|
||||
return nil, nil
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
|
||||
return updates, err
|
||||
}
|
||||
|
||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user