mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-11 02:24:45 +08:00
Compare commits
267 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
04b4d7de3a | ||
|
|
b11cfdc11f | ||
|
|
cdddbec629 | ||
|
|
b3fe0506fb | ||
|
|
8ca0e2772e | ||
|
|
826090e099 | ||
|
|
7399de6ecc | ||
|
|
25cb5e7505 | ||
|
|
5c13ec3121 | ||
|
|
d8aff3a7e3 | ||
|
|
f44927b9f8 | ||
|
|
c0110cb5af | ||
|
|
1f8e1142a0 | ||
|
|
1e51de88d6 | ||
|
|
30995b5397 | ||
|
|
eb60f67054 | ||
|
|
78193ceec1 | ||
|
|
f0e08e7687 | ||
|
|
10b8259259 | ||
|
|
eb0b77bf4d | ||
|
|
9d81467937 | ||
|
|
fd8ccaf01a | ||
|
|
2b30e3b6d7 | ||
|
|
6e90ec6111 | ||
|
|
8dd38f4775 | ||
|
|
fbd73f248f | ||
|
|
3fcefe6c32 | ||
|
|
f740d2c291 | ||
|
|
bf6585a40f | ||
|
|
8c2dd7b3f0 | ||
|
|
4167c437a8 | ||
|
|
0ddaef3c9a | ||
|
|
2fc6aaf936 | ||
|
|
1c0519f1c7 | ||
|
|
6bbe7800be | ||
|
|
2694149489 | ||
|
|
a17ac50118 | ||
|
|
656a77d585 | ||
|
|
7455476c60 | ||
|
|
5f6e929d61 | ||
|
|
4caa3c2701 | ||
|
|
36cda57c81 | ||
|
|
9f1f203b84 | ||
|
|
b41a8ca93f | ||
|
|
e3cf0c0e10 | ||
|
|
de18bce9aa | ||
|
|
3cc407bc0e | ||
|
|
00a0a12138 | ||
|
|
b08767a4f9 | ||
|
|
ac6bde7a98 | ||
|
|
d2d41d68dd | ||
|
|
944b7f7617 | ||
|
|
53825eb073 | ||
|
|
1a7f49513f | ||
|
|
885a2ce7ef | ||
|
|
14ba80a0af | ||
|
|
5fa22fdf82 | ||
|
|
bcaae2eb91 | ||
|
|
767a41e263 | ||
|
|
252d6c5301 | ||
|
|
7a4e65ad4b | ||
|
|
a582aa89a9 | ||
|
|
acefa1da12 | ||
|
|
a88698f3fc | ||
|
|
ebc6755b33 | ||
|
|
c8eff34388 | ||
|
|
f19b03825b | ||
|
|
b43ee62947 | ||
|
|
106b20cdbf | ||
|
|
c069b3b1e8 | ||
|
|
a9ecd7bcc6 | ||
|
|
f89465fb39 | ||
|
|
440c3f46a7 | ||
|
|
c746964936 | ||
|
|
b2d6879b3f | ||
|
|
ce4095904e | ||
|
|
0d2dcbff11 | ||
|
|
02dca78dbe | ||
|
|
8df299b767 | ||
|
|
95cf59b2f6 | ||
|
|
a0f643da0e | ||
|
|
99331a5285 | ||
|
|
9245e197a8 | ||
|
|
5fa1d0978f | ||
|
|
0bb683a6f1 | ||
|
|
cb3958bac3 | ||
|
|
91ebf95efa | ||
|
|
7895dd65b2 | ||
|
|
09f8894906 | ||
|
|
45fdef9da7 | ||
|
|
704f256bd7 | ||
|
|
a6026e7ac4 | ||
|
|
1e03b2974a | ||
|
|
daa7c783b9 | ||
|
|
8a82a2a648 | ||
|
|
c3ac68af2a | ||
|
|
ec3897b981 | ||
|
|
62486cee37 | ||
|
|
7c5746ffbc | ||
|
|
47f7b0213b | ||
|
|
8df42f7aab | ||
|
|
d666e05a6d | ||
|
|
c37edf2de5 | ||
|
|
bb9af2465e | ||
|
|
7af00864b3 | ||
|
|
81903e87e3 | ||
|
|
0b96c7a65e | ||
|
|
34ccfe45ea | ||
|
|
d4231150a9 | ||
|
|
2268e93aec | ||
|
|
c976916b6d | ||
|
|
0bbd003a18 | ||
|
|
336a844712 | ||
|
|
51e7f262bd | ||
|
|
99663a3f20 | ||
|
|
4d88248091 | ||
|
|
c143367e15 | ||
|
|
ca44389baa | ||
|
|
05c2a65ef0 | ||
|
|
3f21d204a7 | ||
|
|
c848f950ce | ||
|
|
87bd765a57 | ||
|
|
49325a769e | ||
|
|
238d86f502 | ||
|
|
7064063230 | ||
|
|
b1de4352a8 | ||
|
|
7bf5c1cbcb | ||
|
|
411e24146d | ||
|
|
c7392fc80b | ||
|
|
c37c68a341 | ||
|
|
9230d3cbc9 | ||
|
|
19925e22d9 | ||
|
|
65e0c1b258 | ||
|
|
94bdde32bb | ||
|
|
14c80d26c6 | ||
|
|
9555a99d1c | ||
|
|
0e69895603 | ||
|
|
cc3cf1d70a | ||
|
|
3382d496e3 | ||
|
|
e0b4b00dc1 | ||
|
|
81d896bf78 | ||
|
|
ec576fdbde | ||
|
|
741eae59bb | ||
|
|
505494b378 | ||
|
|
c1033c12bd | ||
|
|
b789333b68 | ||
|
|
61ef73cb12 | ||
|
|
e71be7e0f1 | ||
|
|
0302c03864 | ||
|
|
d21fe54d55 | ||
|
|
a70d3ff82d | ||
|
|
574359f1df | ||
|
|
6da2f54e50 | ||
|
|
886464b2e9 | ||
|
|
396044e354 | ||
|
|
3d15202124 | ||
|
|
756b09b6b8 | ||
|
|
f4d3fadd6f | ||
|
|
1fb6e9e830 | ||
|
|
78ac6a7a29 | ||
|
|
efe8810dff | ||
|
|
5c07e11473 | ||
|
|
d552ad7673 | ||
|
|
496173da1f | ||
|
|
1cdaf33272 | ||
|
|
e2b3969492 | ||
|
|
8661bf8837 | ||
|
|
292fa7a6d2 | ||
|
|
39d7300a8e | ||
|
|
0ad20c9489 | ||
|
|
7eb3b23ddf | ||
|
|
5b06542193 | ||
|
|
38dca4f787 | ||
|
|
737d1ecf5b | ||
|
|
f819cef6d5 | ||
|
|
facae2a6db | ||
|
|
8b021c099d | ||
|
|
8a625188ce | ||
|
|
c0cfa6acde | ||
|
|
0913cfc082 | ||
|
|
59e8465325 | ||
|
|
be6e8ff77b | ||
|
|
ebb85cf843 | ||
|
|
ef959bc3c6 | ||
|
|
8cb7356bbc | ||
|
|
496545188a | ||
|
|
739c80227a | ||
|
|
6218eefd61 | ||
|
|
5715587baf | ||
|
|
ae770a625b | ||
|
|
428ee065d3 | ||
|
|
320ca28f90 | ||
|
|
5e518f5fbd | ||
|
|
24dcba1d72 | ||
|
|
1abc688cad | ||
|
|
34936189d8 | ||
|
|
daf7bf3e8b | ||
|
|
3a9f1c5796 | ||
|
|
bb1e205516 | ||
|
|
9af4a55176 | ||
|
|
51e903c34e | ||
|
|
7f03319646 | ||
|
|
0c33d18a4d | ||
|
|
a747c63b8e | ||
|
|
a03c361b04 | ||
|
|
807d0018ef | ||
|
|
850e267763 | ||
|
|
c75ae56f10 | ||
|
|
caaed775aa | ||
|
|
d11b295729 | ||
|
|
91d0059f8d | ||
|
|
44693d0dfb | ||
|
|
c722212e12 | ||
|
|
f12da65962 | ||
|
|
92745f7534 | ||
|
|
f2917aeaf8 | ||
|
|
a1e2ffd586 | ||
|
|
78a9705fad | ||
|
|
5b6da04a02 | ||
|
|
4661c2f90f | ||
|
|
90e4328885 | ||
|
|
130112a84a | ||
|
|
b368bb6ea1 | ||
|
|
9ecb6211d6 | ||
|
|
79fba9c8d3 | ||
|
|
37c76a93ab | ||
|
|
86e600aa52 | ||
|
|
6a1c28b70e | ||
|
|
9375f1809c | ||
|
|
f176150c93 | ||
|
|
30d25084f0 | ||
|
|
91ad94d941 | ||
|
|
57a778dccf | ||
|
|
f702c66659 | ||
|
|
a095468850 | ||
|
|
f9b6a20995 | ||
|
|
f2770da880 | ||
|
|
d269659e61 | ||
|
|
c4d6715443 | ||
|
|
fc4a1c5433 | ||
|
|
6bdd580b3f | ||
|
|
9cf4882f4c | ||
|
|
406dad998d | ||
|
|
8b0db22c18 | ||
|
|
f06048eccf | ||
|
|
05f5a8b61d | ||
|
|
662625a091 | ||
|
|
6328e69441 | ||
|
|
425dfb80d9 | ||
|
|
4c1fd570f0 | ||
|
|
345f853b5d | ||
|
|
100a70f87c | ||
|
|
18b591bc3b | ||
|
|
6a52b24369 | ||
|
|
228aca9523 | ||
|
|
7e4637cd70 | ||
|
|
3e3c015efa | ||
|
|
30c30b1712 | ||
|
|
e666356483 | ||
|
|
3710bc883b | ||
|
|
64f60d15b0 | ||
|
|
bc4a044337 | ||
|
|
cb233bfa66 | ||
|
|
d46059a735 | ||
|
|
da2fbd9924 | ||
|
|
084e0adb34 | ||
|
|
3bddbb6afe |
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -17,6 +17,7 @@ jobs:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache: true
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.26.1'
|
||||
@@ -36,6 +37,7 @@ jobs:
|
||||
go-version-file: backend/go.mod
|
||||
check-latest: false
|
||||
cache: true
|
||||
cache-dependency-path: backend/go.sum
|
||||
- name: Verify Go version
|
||||
run: |
|
||||
go version | grep -q 'go1.26.1'
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -78,6 +78,7 @@ Desktop.ini
|
||||
# ===================
|
||||
tmp/
|
||||
temp/
|
||||
logs/
|
||||
*.tmp
|
||||
*.temp
|
||||
*.log
|
||||
@@ -128,8 +129,15 @@ deploy/docker-compose.override.yml
|
||||
vite.config.js
|
||||
docs/*
|
||||
.serena/
|
||||
|
||||
# ===================
|
||||
# 压测工具
|
||||
# ===================
|
||||
tools/loadtest/
|
||||
# Antigravity Manager
|
||||
Antigravity-Manager/
|
||||
antigravity_projectid_fix.patch
|
||||
.codex/
|
||||
frontend/coverage/
|
||||
aicodex
|
||||
output/
|
||||
|
||||
|
||||
28
README.md
28
README.md
@@ -150,14 +150,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# Start services
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# View logs
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**What the script does:**
|
||||
- Downloads `docker-compose.local.yml` and `.env.example`
|
||||
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
|
||||
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
|
||||
- Creates `.env` file with auto-generated secrets
|
||||
- Creates data directories (uses local directories for easy backup/migration)
|
||||
@@ -522,6 +522,28 @@ sub2api/
|
||||
└── install.sh # One-click installation script
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
|
||||
> **Please read carefully before using this project:**
|
||||
>
|
||||
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
|
||||
>
|
||||
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
28
README_CN.md
28
README_CN.md
@@ -154,14 +154,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
|
||||
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
|
||||
|
||||
# 启动服务
|
||||
docker-compose -f docker-compose.local.yml up -d
|
||||
docker-compose up -d
|
||||
|
||||
# 查看日志
|
||||
docker-compose -f docker-compose.local.yml logs -f sub2api
|
||||
docker-compose logs -f sub2api
|
||||
```
|
||||
|
||||
**脚本功能:**
|
||||
- 下载 `docker-compose.local.yml` 和 `.env.example`
|
||||
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml`)和 `.env.example`
|
||||
- 自动生成安全凭证(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)
|
||||
- 创建 `.env` 文件并填充自动生成的密钥
|
||||
- 创建数据目录(使用本地目录,便于备份和迁移)
|
||||
@@ -588,6 +588,28 @@ sub2api/
|
||||
└── install.sh # 一键安装脚本
|
||||
```
|
||||
|
||||
## 免责声明
|
||||
|
||||
> **使用本项目前请仔细阅读:**
|
||||
>
|
||||
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
|
||||
>
|
||||
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
|
||||
|
||||
---
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.1.88
|
||||
0.1.96.1
|
||||
|
||||
@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
@@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
|
||||
@@ -62,26 +62,28 @@ type Group struct {
|
||||
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||
// 是否仅允许 Claude Code 客户端
|
||||
// allow Claude Code client only
|
||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||
// 非 Claude Code 请求降级使用的分组 ID
|
||||
// fallback group for non-Claude-Code requests
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
// 无效请求兜底使用的分组 ID
|
||||
// fallback group for invalid request
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||
// model routing config: pattern -> account ids
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
// 是否启用模型路由配置
|
||||
// whether model routing is enabled
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
|
||||
// whether MCP XML prompt injection is enabled
|
||||
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
||||
// 支持的模型系列:claude, gemini_text, gemini_image
|
||||
// supported model scopes: claude, gemini_text, gemini_image
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
// 分组显示排序,数值越小越靠前
|
||||
// group display order, lower comes first
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// simulate claude usage as claude-max style (1h cache write)
|
||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -190,7 +192,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldSimulateClaudeMaxEnabled:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
@@ -431,6 +433,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.DefaultMappedModel = value.String
|
||||
}
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i])
|
||||
} else if value.Valid {
|
||||
_m.SimulateClaudeMaxEnabled = value.Bool
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -630,6 +638,9 @@ func (_m *Group) String() string {
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("default_mapped_model=")
|
||||
builder.WriteString(_m.DefaultMappedModel)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("simulate_claude_max_enabled=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled))
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -79,6 +79,8 @@ const (
|
||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database.
|
||||
FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -186,6 +188,7 @@ var Columns = []string{
|
||||
FieldSortOrder,
|
||||
FieldAllowMessagesDispatch,
|
||||
FieldDefaultMappedModel,
|
||||
FieldSimulateClaudeMaxEnabled,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -259,6 +262,8 @@ var (
|
||||
DefaultDefaultMappedModel string
|
||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
DefaultMappedModelValidator func(string) error
|
||||
// DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field.
|
||||
DefaultSimulateClaudeMaxEnabled bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -419,6 +424,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// BySimulateClaudeMaxEnabled orders the results by the simulate_claude_max_enabled field.
|
||||
func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -205,6 +205,11 @@ func DefaultMappedModel(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// SimulateClaudeMaxEnabled applies equality check predicate on the "simulate_claude_max_enabled" field. It's identical to SimulateClaudeMaxEnabledEQ.
|
||||
func SimulateClaudeMaxEnabled(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -1555,6 +1560,16 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
|
||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||
}
|
||||
|
||||
// SimulateClaudeMaxEnabledEQ applies the EQ predicate on the "simulate_claude_max_enabled" field.
|
||||
func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||
}
|
||||
|
||||
// SimulateClaudeMaxEnabledNEQ applies the NEQ predicate on the "simulate_claude_max_enabled" field.
|
||||
func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -452,6 +452,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate {
|
||||
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetSimulateClaudeMaxEnabled(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -649,6 +663,10 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultDefaultMappedModel
|
||||
_c.mutation.SetDefaultMappedModel(v)
|
||||
}
|
||||
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
||||
v := group.DefaultSimulateClaudeMaxEnabled
|
||||
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -730,6 +748,9 @@ func (_c *GroupCreate) check() error {
|
||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
||||
return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -885,6 +906,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
_node.DefaultMappedModel = value
|
||||
}
|
||||
if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||
_node.SimulateClaudeMaxEnabled = value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1599,6 +1624,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldSimulateClaudeMaxEnabled, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldSimulateClaudeMaxEnabled)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -2295,6 +2332,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSimulateClaudeMaxEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSimulateClaudeMaxEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -3157,6 +3208,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetSimulateClaudeMaxEnabled(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateSimulateClaudeMaxEnabled()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -653,6 +653,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate {
|
||||
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetSimulateClaudeMaxEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1149,6 +1163,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -2081,6 +2098,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetSimulateClaudeMaxEnabled(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -2607,6 +2638,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -410,6 +410,7 @@ var (
|
||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
{Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
|
||||
@@ -8252,6 +8252,7 @@ type GroupMutation struct {
|
||||
addsort_order *int
|
||||
allow_messages_dispatch *bool
|
||||
default_mapped_model *string
|
||||
simulate_claude_max_enabled *bool
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -10068,6 +10069,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
|
||||
m.default_mapped_model = nil
|
||||
}
|
||||
|
||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
||||
func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) {
|
||||
m.simulate_claude_max_enabled = &b
|
||||
}
|
||||
|
||||
// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation.
|
||||
func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) {
|
||||
v := m.simulate_claude_max_enabled
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", err)
|
||||
}
|
||||
return oldValue.SimulateClaudeMaxEnabled, nil
|
||||
}
|
||||
|
||||
// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field.
|
||||
func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() {
|
||||
m.simulate_claude_max_enabled = nil
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -10426,7 +10463,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 32)
|
||||
fields := make([]string, 0, 33)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -10523,6 +10560,9 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.default_mapped_model != nil {
|
||||
fields = append(fields, group.FieldDefaultMappedModel)
|
||||
}
|
||||
if m.simulate_claude_max_enabled != nil {
|
||||
fields = append(fields, group.FieldSimulateClaudeMaxEnabled)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -10595,6 +10635,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.AllowMessagesDispatch()
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.DefaultMappedModel()
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
return m.SimulateClaudeMaxEnabled()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -10668,6 +10710,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldAllowMessagesDispatch(ctx)
|
||||
case group.FieldDefaultMappedModel:
|
||||
return m.OldDefaultMappedModel(ctx)
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
return m.OldSimulateClaudeMaxEnabled(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -10901,6 +10945,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetDefaultMappedModel(v)
|
||||
return nil
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetSimulateClaudeMaxEnabled(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -11334,6 +11385,9 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldDefaultMappedModel:
|
||||
m.ResetDefaultMappedModel()
|
||||
return nil
|
||||
case group.FieldSimulateClaudeMaxEnabled:
|
||||
m.ResetSimulateClaudeMaxEnabled()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
|
||||
@@ -463,6 +463,10 @@ func init() {
|
||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||
// groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field.
|
||||
groupDescSimulateClaudeMaxEnabled := groupFields[29].Descriptor()
|
||||
// group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field.
|
||||
group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||
_ = idempotencyrecordMixinFields0
|
||||
|
||||
@@ -33,8 +33,6 @@ func (Group) Mixin() []ent.Mixin {
|
||||
|
||||
func (Group) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用
|
||||
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
|
||||
field.String("name").
|
||||
MaxLen(100).
|
||||
NotEmpty(),
|
||||
@@ -51,7 +49,6 @@ func (Group) Fields() []ent.Field {
|
||||
MaxLen(20).
|
||||
Default(domain.StatusActive),
|
||||
|
||||
// Subscription-related fields (added by migration 003)
|
||||
field.String("platform").
|
||||
MaxLen(50).
|
||||
Default(domain.PlatformAnthropic),
|
||||
@@ -73,7 +70,6 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int("default_validity_days").
|
||||
Default(30),
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
field.Float("image_price_1k").
|
||||
Optional().
|
||||
Nillable().
|
||||
@@ -87,7 +83,6 @@ func (Group) Fields() []ent.Field {
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
|
||||
// Sora 按次计费配置(阶段 1)
|
||||
field.Float("sora_image_price_360").
|
||||
Optional().
|
||||
Nillable().
|
||||
@@ -109,45 +104,38 @@ func (Group) Fields() []ent.Field {
|
||||
field.Int64("sora_storage_quota_bytes").
|
||||
Default(0),
|
||||
|
||||
// Claude Code 客户端限制 (added by migration 029)
|
||||
field.Bool("claude_code_only").
|
||||
Default(false).
|
||||
Comment("是否仅允许 Claude Code 客户端"),
|
||||
Comment("allow Claude Code client only"),
|
||||
field.Int64("fallback_group_id").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||
Comment("fallback group for non-Claude-Code requests"),
|
||||
field.Int64("fallback_group_id_on_invalid_request").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("无效请求兜底使用的分组 ID"),
|
||||
Comment("fallback group for invalid request"),
|
||||
|
||||
// 模型路由配置 (added by migration 040)
|
||||
field.JSON("model_routing", map[string][]int64{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
|
||||
|
||||
// 模型路由开关 (added by migration 041)
|
||||
Comment("model routing config: pattern -> account ids"),
|
||||
field.Bool("model_routing_enabled").
|
||||
Default(false).
|
||||
Comment("是否启用模型路由配置"),
|
||||
Comment("whether model routing is enabled"),
|
||||
|
||||
// MCP XML 协议注入开关 (added by migration 042)
|
||||
field.Bool("mcp_xml_inject").
|
||||
Default(true).
|
||||
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
|
||||
Comment("whether MCP XML prompt injection is enabled"),
|
||||
|
||||
// 支持的模型系列 (added by migration 046)
|
||||
field.JSON("supported_model_scopes", []string{}).
|
||||
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
|
||||
Comment("supported model scopes: claude, gemini_text, gemini_image"),
|
||||
|
||||
// 分组排序 (added by migration 052)
|
||||
field.Int("sort_order").
|
||||
Default(0).
|
||||
Comment("分组显示排序,数值越小越靠前"),
|
||||
Comment("group display order, lower comes first"),
|
||||
|
||||
// OpenAI Messages 调度配置 (added by migration 069)
|
||||
field.Bool("allow_messages_dispatch").
|
||||
@@ -157,6 +145,9 @@ func (Group) Fields() []ent.Field {
|
||||
MaxLen(100).
|
||||
Default("").
|
||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||
field.Bool("simulate_claude_max_enabled").
|
||||
Default(false).
|
||||
Comment("simulate claude usage as claude-max style (1h cache write)"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,14 +163,11 @@ func (Group) Edges() []ent.Edge {
|
||||
edge.From("allowed_users", User.Type).
|
||||
Ref("allowed_groups").
|
||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
||||
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
||||
}
|
||||
}
|
||||
|
||||
func (Group) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
|
||||
index.Fields("status"),
|
||||
index.Fields("platform"),
|
||||
index.Fields("subscription_type"),
|
||||
|
||||
@@ -87,6 +87,7 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
@@ -137,6 +138,8 @@ require (
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
|
||||
@@ -124,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
@@ -180,6 +182,7 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@@ -199,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -281,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -333,6 +342,8 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@@ -341,8 +352,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
||||
@@ -432,11 +441,11 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
|
||||
@@ -84,10 +84,12 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"claude-haiku-4-5": "claude-sonnet-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
|
||||
// Gemini 2.5 白名单
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
// Gemini 3 白名单
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
|
||||
@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]string{
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
|
||||
}
|
||||
}
|
||||
|
||||
enrichCredentialsFromIDToken(&item)
|
||||
|
||||
accountInput := &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
|
||||
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
|
||||
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
|
||||
// Existing credential values are never overwritten — only missing fields are filled.
|
||||
func enrichCredentialsFromIDToken(item *DataAccount) {
|
||||
if item.Credentials == nil {
|
||||
return
|
||||
}
|
||||
// Only enrich OpenAI/Sora OAuth accounts
|
||||
platform := strings.ToLower(strings.TrimSpace(item.Platform))
|
||||
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
|
||||
return
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, _ := item.Credentials["id_token"].(string)
|
||||
if strings.TrimSpace(idToken) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeIDToken skips expiry validation — safe for imported data
|
||||
claims, err := openai.DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := claims.GetUserInfo()
|
||||
if userInfo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fill missing fields only (never overwrite existing values)
|
||||
setIfMissing := func(key, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
if existing, _ := item.Credentials[key].(string); existing == "" {
|
||||
item.Credentials[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
setIfMissing("email", userInfo.Email)
|
||||
setIfMissing("plan_type", userInfo.PlanType)
|
||||
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
|
||||
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
|
||||
setIfMissing("organization_id", userInfo.OrganizationID)
|
||||
}
|
||||
|
||||
func normalizeProxyStatus(status string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(status))
|
||||
switch normalized {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -626,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
// TestAccountRequest represents the request body for testing an account
|
||||
type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type SyncFromCRSRequest struct {
|
||||
@@ -656,7 +659,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
@@ -751,52 +754,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Only refresh OAuth-based accounts (oauth and setup-token)
|
||||
// refreshSingleAccount refreshes credentials for a single OAuth account.
|
||||
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
|
||||
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
|
||||
if !account.IsOAuth() {
|
||||
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
|
||||
return
|
||||
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
|
||||
}
|
||||
|
||||
var newCredentials map[string]any
|
||||
|
||||
if account.IsOpenAI() {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Build new credentials from token info
|
||||
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Preserve non-token settings from existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformGemini {
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
|
||||
}
|
||||
|
||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -806,10 +788,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformAntigravity {
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
@@ -828,37 +809,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
return updatedAccount, "missing_project_id_temporary", nil
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
|
||||
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
@@ -880,20 +851,51 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
// Refresh handles refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/:id/refresh
|
||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
if warning == "missing_project_id_temporary" {
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
@@ -949,14 +951,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// BatchClearError handles batch clearing account errors
|
||||
// POST /api/v1/admin/accounts/batch-clear-error
|
||||
func (h *AccountHandler) BatchClearError(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, id := range req.AccountIDs {
|
||||
accountID := id // 闭包捕获
|
||||
g.Go(func() error {
|
||||
account, err := h.adminService.ClearAccountError(gctx, accountID)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": accountID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
|
||||
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
successCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchRefresh handles batch refreshing account credentials
|
||||
// POST /api/v1/admin/accounts/batch-refresh
|
||||
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
|
||||
var req struct {
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.AccountIDs) == 0 {
|
||||
response.BadRequest(c, "account_ids is required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 建立已获取账号的 ID 集合,检测缺失的 ID
|
||||
foundIDs := make(map[int64]bool, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc != nil {
|
||||
foundIDs[acc.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
var successCount, failedCount int
|
||||
var errors []gin.H
|
||||
var warnings []gin.H
|
||||
|
||||
// 将不存在的账号 ID 标记为失败
|
||||
for _, id := range req.AccountIDs {
|
||||
if !foundIDs[id] {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": id,
|
||||
"error": "account not found",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 注意:所有 goroutine 必须 return nil,避免 errgroup cancel 其他并发任务
|
||||
for _, account := range accounts {
|
||||
acc := account // 闭包捕获
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
g.Go(func() error {
|
||||
_, warning, err := h.refreshSingleAccount(gctx, acc)
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
failedCount++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
successCount++
|
||||
if warning != "" {
|
||||
warnings = append(warnings, gin.H{
|
||||
"account_id": acc.ID,
|
||||
"warning": warning,
|
||||
})
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": len(req.AccountIDs),
|
||||
"success": successCount,
|
||||
"failed": failedCount,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
})
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
// POST /api/v1/admin/accounts/batch
|
||||
func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
@@ -1175,6 +1338,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
|
||||
@@ -175,6 +175,10 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
@@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"models": stats,
|
||||
@@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"groups": stats,
|
||||
@@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
@@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
|
||||
limit = 12
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"trend": trend,
|
||||
|
||||
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
118
backend/internal/handler/admin/dashboard_handler_cache_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCacheProbe struct {
|
||||
service.UsageLogRepository
|
||||
trendCalls atomic.Int32
|
||||
usersTrendCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
r.trendCalls.Add(1)
|
||||
return []usagestats.TrendDataPoint{{
|
||||
Date: "2026-03-11",
|
||||
Requests: 1,
|
||||
TotalTokens: 2,
|
||||
Cost: 3,
|
||||
ActualCost: 4,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
limit int,
|
||||
) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
r.usersTrendCalls.Add(1)
|
||||
return []usagestats.UserUsageTrendPoint{{
|
||||
Date: "2026-03-11",
|
||||
UserID: 1,
|
||||
Email: "cache@test.dev",
|
||||
Requests: 2,
|
||||
Tokens: 20,
|
||||
Cost: 2,
|
||||
ActualCost: 1,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func resetDashboardReadCachesForTest() {
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.trendCalls.Load())
|
||||
}
|
||||
|
||||
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
|
||||
t.Cleanup(resetDashboardReadCachesForTest)
|
||||
resetDashboardReadCachesForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &dashboardUsageRepoCacheProbe{}
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
|
||||
}
|
||||
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
200
backend/internal/handler/admin/dashboard_query_cache.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
var (
|
||||
dashboardTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
|
||||
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
|
||||
)
|
||||
|
||||
type dashboardTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Model string `json:"model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardModelGroupCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
}
|
||||
|
||||
type dashboardEntityTrendCacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
func cacheStatusValue(hit bool) string {
|
||||
if hit {
|
||||
return "hit"
|
||||
}
|
||||
return "miss"
|
||||
}
|
||||
|
||||
func mustMarshalDashboardCacheKey(value any) string {
|
||||
raw, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func snapshotPayloadAs[T any](payload any) (T, error) {
|
||||
typed, ok := payload.(T)
|
||||
if !ok {
|
||||
var zero T
|
||||
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
|
||||
}
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUsageTrendCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getModelStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getGroupStatsCached(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.GroupStat, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
})
|
||||
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
|
||||
return stats, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
|
||||
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
Limit: limit,
|
||||
})
|
||||
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
|
||||
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, hit, err
|
||||
}
|
||||
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
|
||||
return trend, hit, err
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
|
||||
return h.buildSnapshotV2Response(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters,
|
||||
includeStats,
|
||||
includeTrend,
|
||||
includeModels,
|
||||
includeGroups,
|
||||
includeUsersTrend,
|
||||
usersTrendLimit,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
response.Error(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
|
||||
response.Success(c, cached.Payload)
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) buildSnapshotV2Response(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
filters *dashboardSnapshotV2Filters,
|
||||
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
|
||||
usersTrendLimit int,
|
||||
) (*dashboardSnapshotV2Response, error) {
|
||||
resp := &dashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
StartDate: startTime.Format("2006-01-02"),
|
||||
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeStats {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
stats, err := h.dashboardService.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get dashboard statistics")
|
||||
}
|
||||
resp.Stats = &dashboardSnapshotV2Stats{
|
||||
DashboardStats: *stats,
|
||||
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
}
|
||||
|
||||
if includeTrend {
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
||||
c.Request.Context(),
|
||||
trend, _, err := h.getUsageTrendCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
@@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get usage trend")
|
||||
}
|
||||
resp.Trend = trend
|
||||
}
|
||||
|
||||
if includeModels {
|
||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
models, _, err := h.getModelStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get model statistics")
|
||||
}
|
||||
resp.Models = models
|
||||
}
|
||||
|
||||
if includeGroups {
|
||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
groups, _, err := h.getGroupStatsCached(
|
||||
ctx,
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
@@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
return nil, errors.New("failed to get group statistics")
|
||||
}
|
||||
resp.Groups = groups
|
||||
}
|
||||
|
||||
if includeUsersTrend {
|
||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
usersTrendLimit,
|
||||
)
|
||||
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
return nil, errors.New("failed to get user usage trend")
|
||||
}
|
||||
resp.UsersTrend = usersTrend
|
||||
}
|
||||
|
||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||
|
||||
@@ -46,9 +46,10 @@ type CreateGroupRequest struct {
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
@@ -84,9 +85,10 @@ type UpdateGroupRequest struct {
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
@@ -207,6 +209,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
@@ -260,6 +263,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||
@@ -335,6 +339,27 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if entries == nil {
|
||||
entries = []service.UserGroupRateEntry{}
|
||||
}
|
||||
response.Success(c, entries)
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent",
|
||||
"concurrency_queue_depth",
|
||||
"group_available_accounts",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_rate_limited_count",
|
||||
"account_error_count",
|
||||
"account_error_ratio",
|
||||
"overload_account_count",
|
||||
}
|
||||
|
||||
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
|
||||
"error_rate",
|
||||
"upstream_error_rate",
|
||||
"cpu_usage_percent",
|
||||
"memory_usage_percent":
|
||||
"memory_usage_percent",
|
||||
"group_available_ratio",
|
||||
"group_rate_limit_ratio",
|
||||
"account_error_ratio":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -1405,6 +1405,61 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetBetaPolicySettings 获取 Beta 策略配置
|
||||
// GET /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
|
||||
for i, r := range settings.Rules {
|
||||
rules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: rules})
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
|
||||
type UpdateBetaPolicySettingsRequest struct {
|
||||
Rules []dto.BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// UpdateBetaPolicySettings 更新 Beta 策略配置
|
||||
// PUT /api/v1/admin/settings/beta-policy
|
||||
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
|
||||
var req UpdateBetaPolicySettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rules := make([]service.BetaPolicyRule, len(req.Rules))
|
||||
for i, r := range req.Rules {
|
||||
rules[i] = service.BetaPolicyRule(r)
|
||||
}
|
||||
|
||||
settings := &service.BetaPolicySettings{Rules: rules}
|
||||
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch to return updated settings
|
||||
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
|
||||
for i, r := range updated.Rules {
|
||||
outRules[i] = dto.BetaPolicyRule(r)
|
||||
}
|
||||
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
type snapshotCacheEntry struct {
|
||||
@@ -19,6 +21,12 @@ type snapshotCache struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
items map[string]snapshotCacheEntry
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
type snapshotCacheLoadResult struct {
|
||||
Entry snapshotCacheEntry
|
||||
Hit bool
|
||||
}
|
||||
|
||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||
return entry
|
||||
}
|
||||
|
||||
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
|
||||
if load == nil {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return entry, true, nil
|
||||
}
|
||||
if c == nil || key == "" {
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
return c.Set(key, payload), false, nil
|
||||
}
|
||||
|
||||
value, err, _ := c.sf.Do(key, func() (any, error) {
|
||||
if entry, ok := c.Get(key); ok {
|
||||
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
|
||||
}
|
||||
payload, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return snapshotCacheEntry{}, false, err
|
||||
}
|
||||
result, ok := value.(snapshotCacheLoadResult)
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false, nil
|
||||
}
|
||||
return result.Entry, result.Hit, nil
|
||||
}
|
||||
|
||||
func buildETagFromAny(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||
require.Empty(t, etag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
|
||||
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"hello": "world"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, hit)
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
|
||||
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
|
||||
loads.Add(1)
|
||||
return map[string]string{"unexpected": "value"}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, hit)
|
||||
require.Equal(t, entry.ETag, entry2.ETag)
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
var loads atomic.Int32
|
||||
start := make(chan struct{})
|
||||
const callers = 8
|
||||
errCh := make(chan error, callers)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
for range callers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, _, err := c.GetOrLoad("shared", func() (any, error) {
|
||||
loads.Add(1)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
return "value", nil
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, int32(1), loads.Load())
|
||||
}
|
||||
|
||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||
type ResetSubscriptionQuotaRequest struct {
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
}
|
||||
|
||||
// ResetQuota resets daily and/or weekly usage for a subscription.
|
||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid subscription ID")
|
||||
return
|
||||
}
|
||||
var req ResetSubscriptionQuotaRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Daily && !req.Weekly {
|
||||
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
||||
return
|
||||
}
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
// DELETE /api/v1/admin/subscriptions/:id
|
||||
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
|
||||
@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
|
||||
if tokenErr != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
|
||||
return
|
||||
}
|
||||
fragment := url.Values{}
|
||||
fragment.Set("error", "invitation_required")
|
||||
fragment.Set("pending_oauth_token", pendingToken)
|
||||
fragment.Set("redirect", redirectTo)
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
return
|
||||
}
|
||||
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
}
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
}
|
||||
|
||||
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||
// the invitation code and creating the user account.
|
||||
// POST /api/v1/auth/oauth/linuxdo/complete-registration
|
||||
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
var req completeLinuxDoOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
|
||||
if h != nil && h.settingSvc != nil {
|
||||
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
|
||||
|
||||
@@ -135,14 +135,15 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
SortOrder: g.SortOrder,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
|
||||
@@ -168,6 +168,19 @@ type RectifierSettings struct {
|
||||
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
|
||||
}
|
||||
|
||||
// BetaPolicyRule Beta 策略规则 DTO
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// BetaPolicySettings Beta 策略配置 DTO
|
||||
type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -117,6 +117,8 @@ type AdminGroup struct {
|
||||
|
||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
// Claude usage 模拟开关(仅管理员可见)
|
||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
|
||||
@@ -439,6 +439,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
@@ -630,6 +631,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// ===== 用户消息串行队列 END =====
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
c.Set("parsed_request", parsedReq)
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
@@ -652,6 +654,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
// Beta policy block: return 400 immediately, no failover
|
||||
var betaBlockedErr *service.BetaBlockedError
|
||||
if errors.As(err, &betaBlockedErr) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||
return
|
||||
}
|
||||
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||
@@ -734,6 +743,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
|
||||
@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
|
||||
@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
|
||||
@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
290
backend/internal/handler/openai_chat_completions.go
Normal file
290
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||
// POST /v1/chat/completions
|
||||
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_chat_completions.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
const (
|
||||
opsErrorLogTimeout = 5 * time.Second
|
||||
opsErrorLogDrainTimeout = 10 * time.Second
|
||||
opsErrorLogBatchWindow = 200 * time.Millisecond
|
||||
|
||||
opsErrorLogMinWorkerCount = 4
|
||||
opsErrorLogMaxWorkerCount = 32
|
||||
@@ -38,6 +39,7 @@ const (
|
||||
opsErrorLogQueueSizePerWorker = 128
|
||||
opsErrorLogMinQueueSize = 256
|
||||
opsErrorLogMaxQueueSize = 8192
|
||||
opsErrorLogBatchSize = 32
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
@@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() {
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go func() {
|
||||
defer opsErrorLogWorkersWg.Done()
|
||||
for job := range opsErrorLogQueue {
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
for {
|
||||
job, ok := <-opsErrorLogQueue
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
|
||||
batch = append(batch, job)
|
||||
|
||||
timer := time.NewTimer(opsErrorLogBatchWindow)
|
||||
batchLoop:
|
||||
for len(batch) < opsErrorLogBatchSize {
|
||||
select {
|
||||
case nextJob, ok := <-opsErrorLogQueue:
|
||||
if !ok {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
return
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
opsErrorLogQueueLen.Add(-1)
|
||||
batch = append(batch, nextJob)
|
||||
case <-timer.C:
|
||||
break batchLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
flushOpsErrorLogBatch(batch)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
|
||||
var processed int64
|
||||
for _, job := range batch {
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
}
|
||||
grouped[job.ops] = append(grouped[job.ops], job.entry)
|
||||
processed++
|
||||
}
|
||||
if processed == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for opsSvc, entries := range grouped {
|
||||
if opsSvc == nil || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = opsSvc.RecordErrorBatch(ctx, entries)
|
||||
cancel()
|
||||
}
|
||||
opsErrorLogProcessed.Add(processed)
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
|
||||
@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
|
||||
@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
|
||||
@@ -18,6 +18,9 @@ const (
|
||||
BlockTypeFunction
|
||||
)
|
||||
|
||||
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
||||
type UsageMapHook func(usageMap map[string]any)
|
||||
|
||||
// StreamingProcessor 流式响应处理器
|
||||
type StreamingProcessor struct {
|
||||
blockType BlockType
|
||||
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
|
||||
originalModel string
|
||||
webSearchQueries []string
|
||||
groundingChunks []GeminiGroundingChunk
|
||||
usageMapHook UsageMapHook
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
@@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
|
||||
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
|
||||
p.usageMapHook = fn
|
||||
}
|
||||
|
||||
func usageToMap(u ClaudeUsage) map[string]any {
|
||||
m := map[string]any{
|
||||
"input_tokens": u.InputTokens,
|
||||
"output_tokens": u.OutputTokens,
|
||||
}
|
||||
if u.CacheCreationInputTokens > 0 {
|
||||
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
||||
}
|
||||
if u.CacheReadInputTokens > 0 {
|
||||
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
line = strings.TrimSpace(line)
|
||||
@@ -168,6 +191,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
responseID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
var usageValue any = usage
|
||||
if p.usageMapHook != nil {
|
||||
usageMap := usageToMap(usage)
|
||||
p.usageMapHook(usageMap)
|
||||
usageValue = usageMap
|
||||
}
|
||||
|
||||
message := map[string]any{
|
||||
"id": responseID,
|
||||
"type": "message",
|
||||
@@ -176,7 +206,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
"model": p.originalModel,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
"usage": usageValue,
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
@@ -487,13 +517,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
var usageValue any = usage
|
||||
if p.usageMapHook != nil {
|
||||
usageMap := usageToMap(usage)
|
||||
p.usageMapHook(usageMap)
|
||||
usageValue = usageMap
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": usage,
|
||||
"usage": usageValue,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||
|
||||
733
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
733
backend/internal/pkg/apicompat/chatcompletions_responses_test.go
Normal file
@@ -0,0 +1,733 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ChatCompletionsToResponses tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hello"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4o", resp.Model)
|
||||
assert.True(t, resp.Stream) // always forced true
|
||||
assert.False(t, *resp.Store)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "system", items[0].Role)
|
||||
assert.Equal(t, "user", items[1].Role)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "ping",
|
||||
Arguments: `{"host":"example.com"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_1",
|
||||
Content: json.RawMessage(`"pong"`),
|
||||
},
|
||||
},
|
||||
Tools: []ChatTool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: &ChatFunction{
|
||||
Name: "ping",
|
||||
Description: "Ping a host",
|
||||
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + function_call + function_call_output = 3
|
||||
// (assistant message with empty content + tool_calls → only function_call items emitted)
|
||||
require.Len(t, items, 3)
|
||||
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "call_1", items[2].CallID)
|
||||
assert.Equal(t, "pong", items[2].Output)
|
||||
|
||||
// Check tools
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "ping", resp.Tools[0].Name)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
|
||||
t.Run("max_tokens", func(t *testing.T) {
|
||||
maxTokens := 100
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
MaxTokens: &maxTokens,
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.MaxOutputTokens)
|
||||
// Below minMaxOutputTokens (128), should be clamped
|
||||
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
|
||||
})
|
||||
|
||||
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
|
||||
maxTokens := 100
|
||||
maxCompletion := 500
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
MaxTokens: &maxTokens,
|
||||
MaxCompletionTokens: &maxCompletion,
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.MaxOutputTokens)
|
||||
assert.Equal(t, 500, *resp.MaxOutputTokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
ReasoningEffort: "high",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
|
||||
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(content)},
|
||||
},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
|
||||
require.Len(t, parts, 2)
|
||||
assert.Equal(t, "input_text", parts[0].Type)
|
||||
assert.Equal(t, "Describe this", parts[0].Text)
|
||||
assert.Equal(t, "input_image", parts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
},
|
||||
Functions: []ChatFunction{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
},
|
||||
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Equal(t, "function", resp.Tools[0].Type)
|
||||
assert.Equal(t, "get_weather", resp.Tools[0].Name)
|
||||
|
||||
// tool_choice should be converted
|
||||
require.NotNil(t, resp.ToolChoice)
|
||||
var tc map[string]any
|
||||
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
|
||||
assert.Equal(t, "function", tc["type"])
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
ServiceTier: "flex",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "flex", resp.ServiceTier)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Do something"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`"Let me call a function."`),
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_abc",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "do_thing",
|
||||
Arguments: `{}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
// user + assistant message (with text) + function_call
|
||||
require.Len(t, items, 3)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ResponsesToChatCompletions tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_123",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "Hello, world!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
assert.Equal(t, "chat.completion", chat.Object)
|
||||
assert.Equal(t, "gpt-4o", chat.Model)
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "Hello, world!", content)
|
||||
|
||||
require.NotNil(t, chat.Usage)
|
||||
assert.Equal(t, 10, chat.Usage.PromptTokens)
|
||||
assert.Equal(t, 5, chat.Usage.CompletionTokens)
|
||||
assert.Equal(t, 15, chat.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_456",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "function_call",
|
||||
CallID: "call_xyz",
|
||||
Name: "get_weather",
|
||||
Arguments: `{"city":"NYC"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)
|
||||
|
||||
msg := chat.Choices[0].Message
|
||||
require.Len(t, msg.ToolCalls, 1)
|
||||
assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
|
||||
assert.Equal(t, "function", msg.ToolCalls[0].Type)
|
||||
assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
|
||||
assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_789",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "reasoning",
|
||||
Summary: []ResponsesSummary{
|
||||
{Type: "summary_text", Text: "I thought about it."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "The answer is 42."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
// Reasoning summary is prepended to text
|
||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_inc",
|
||||
Status: "incomplete",
|
||||
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{
|
||||
{Type: "output_text", Text: "partial..."},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "length", chat.Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_cache",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
|
||||
},
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 10,
|
||||
TotalTokens: 110,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 80,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.NotNil(t, chat.Usage)
|
||||
require.NotNil(t, chat.Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_ws",
|
||||
Status: "completed",
|
||||
Output: []ResponsesOutput{
|
||||
{
|
||||
Type: "web_search_call",
|
||||
Action: &WebSearchAction{Type: "search", Query: "test"},
|
||||
},
|
||||
{
|
||||
Type: "message",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||
require.Len(t, chat.Choices, 1)
|
||||
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
assert.Equal(t, "search results", content)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesEventToChatChunks tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
// response.created → role chunk
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{
|
||||
ID: "resp_stream",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
|
||||
assert.True(t, state.SentRole)
|
||||
|
||||
// response.output_text.delta → content chunk
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "Hello",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 1,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
CallID: "call_1",
|
||||
Name: "get_weather",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
|
||||
tc := chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
assert.Equal(t, "call_1", tc.ID)
|
||||
assert.Equal(t, "get_weather", tc.Function.Name)
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index)
|
||||
|
||||
// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 1, // matches the output_index from output_item.added above
|
||||
Delta: `{"city":`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
|
||||
assert.Equal(t, `{"city":`, tc.Function.Arguments)
|
||||
|
||||
// Add a second function call at output_index=2
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_item.added",
|
||||
OutputIndex: 2,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
CallID: "call_2",
|
||||
Name: "get_time",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")
|
||||
|
||||
// Argument delta for second tool call
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 2,
|
||||
Delta: `{"tz":"UTC"}`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")
|
||||
|
||||
// Argument delta for first tool call (interleaved)
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.function_call_arguments.delta",
|
||||
OutputIndex: 1,
|
||||
Delta: `"Tokyo"}`,
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
|
||||
require.NotNil(t, tc.Index)
|
||||
assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 50,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 70,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
// finish chunk + usage chunk
|
||||
require.Len(t, chunks, 2)
|
||||
|
||||
// First chunk: finish_reason
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
|
||||
// Second chunk: usage
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
|
||||
assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
|
||||
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SawToolCall = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
},
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
state.Usage = &ChatUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
}
|
||||
|
||||
chunks := FinalizeResponsesChatStream(state)
|
||||
require.Len(t, chunks, 2)
|
||||
|
||||
// Finish chunk
|
||||
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||
|
||||
// Usage chunk
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 100, chunks[1].Usage.PromptTokens)
|
||||
|
||||
// Idempotent: second call returns nil
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
|
||||
// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
|
||||
// must be a no-op (prevents double finish_reason being sent to the client).
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
// Simulate response.completed
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
require.NotEmpty(t, chunks) // finish + usage chunks
|
||||
|
||||
// Now FinalizeResponsesChatStream should return nil — already finalized.
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestChatChunkToSSE(t *testing.T) {
|
||||
chunk := ChatCompletionsChunk{
|
||||
ID: "chatcmpl-test",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1700000000,
|
||||
Model: "gpt-4o",
|
||||
Choices: []ChatChunkChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: ChatDelta{Role: "assistant"},
|
||||
FinishReason: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sse, err := ChatChunkToSSE(chunk)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, sse, "data: ")
|
||||
assert.Contains(t, sse, "chatcmpl-test")
|
||||
assert.Contains(t, sse, "assistant")
|
||||
assert.True(t, len(sse) > 10)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stream round-trip test
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
|
||||
// Simulate: client sends chat completions request, upstream returns Responses SSE events.
|
||||
// Verify that the streaming state machine produces correct chat completions chunks.
|
||||
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
var allChunks []ChatCompletionsChunk
|
||||
|
||||
// 1. response.created
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.created",
|
||||
Response: &ResponsesResponse{ID: "resp_rt"},
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
|
||||
// 2. text deltas
|
||||
for _, text := range []string{"Hello", ", ", "world", "!"} {
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: text,
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
}
|
||||
|
||||
// 3. response.completed
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 4,
|
||||
TotalTokens: 14,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
allChunks = append(allChunks, chunks...)
|
||||
|
||||
// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
|
||||
require.Len(t, allChunks, 7)
|
||||
|
||||
// First chunk has role
|
||||
assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)
|
||||
|
||||
// Text chunks
|
||||
var fullText string
|
||||
for i := 1; i <= 4; i++ {
|
||||
require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
|
||||
fullText += *allChunks[i].Choices[0].Delta.Content
|
||||
}
|
||||
assert.Equal(t, "Hello, world!", fullText)
|
||||
|
||||
// Finish chunk
|
||||
require.NotNil(t, allChunks[5].Choices[0].FinishReason)
|
||||
assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)
|
||||
|
||||
// Usage chunk
|
||||
require.NotNil(t, allChunks[6].Usage)
|
||||
assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
|
||||
assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)
|
||||
|
||||
// All chunks share the same ID
|
||||
for _, c := range allChunks {
|
||||
assert.Equal(t, "resp_rt", c.ID)
|
||||
}
|
||||
}
|
||||
312
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
312
backend/internal/pkg/apicompat/chatcompletions_to_responses.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
// Responses API request. The upstream always streams, so Stream is forced to
|
||||
// true. store is always false and reasoning.encrypted_content is always
|
||||
// included so that the response translator has full context.
|
||||
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
|
||||
input, err := convertChatMessagesToResponsesInput(req.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &ResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: true, // upstream always streams
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
ServiceTier: req.ServiceTier,
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
|
||||
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
|
||||
maxTokens := 0
|
||||
if req.MaxTokens != nil {
|
||||
maxTokens = *req.MaxTokens
|
||||
}
|
||||
if req.MaxCompletionTokens != nil {
|
||||
maxTokens = *req.MaxCompletionTokens
|
||||
}
|
||||
if maxTokens > 0 {
|
||||
v := maxTokens
|
||||
if v < minMaxOutputTokens {
|
||||
v = minMaxOutputTokens
|
||||
}
|
||||
out.MaxOutputTokens = &v
|
||||
}
|
||||
|
||||
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
|
||||
if req.ReasoningEffort != "" {
|
||||
out.Reasoning = &ResponsesReasoning{
|
||||
Effort: req.ReasoningEffort,
|
||||
Summary: "auto",
|
||||
}
|
||||
}
|
||||
|
||||
// tools[] and legacy functions[] → ResponsesTool[]
|
||||
if len(req.Tools) > 0 || len(req.Functions) > 0 {
|
||||
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
|
||||
}
|
||||
|
||||
// tool_choice: already compatible format — pass through directly.
|
||||
// Legacy function_call needs mapping.
|
||||
if len(req.ToolChoice) > 0 {
|
||||
out.ToolChoice = req.ToolChoice
|
||||
} else if len(req.FunctionCall) > 0 {
|
||||
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert function_call: %w", err)
|
||||
}
|
||||
out.ToolChoice = tc
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// convertChatMessagesToResponsesInput converts the Chat Completions messages
|
||||
// array into a Responses API input items array.
|
||||
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var out []ResponsesInputItem
|
||||
for _, m := range msgs {
|
||||
items, err := chatMessageToResponsesItems(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, items...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// chatMessageToResponsesItems converts a single ChatMessage into one or more
|
||||
// ResponsesInputItem values.
|
||||
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
switch m.Role {
|
||||
case "system":
|
||||
return chatSystemToResponses(m)
|
||||
case "user":
|
||||
return chatUserToResponses(m)
|
||||
case "assistant":
|
||||
return chatAssistantToResponses(m)
|
||||
case "tool":
|
||||
return chatToolToResponses(m)
|
||||
case "function":
|
||||
return chatFunctionToResponses(m)
|
||||
default:
|
||||
return chatUserToResponses(m)
|
||||
}
|
||||
}
|
||||
|
||||
// chatSystemToResponses converts a system message.
|
||||
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
text, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content, err := json.Marshal(text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
|
||||
}
|
||||
|
||||
// chatUserToResponses converts a user message, handling both plain strings and
|
||||
// multi-modal content arrays.
|
||||
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string first.
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var parts []ChatContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err != nil {
|
||||
return nil, fmt.Errorf("parse user content: %w", err)
|
||||
}
|
||||
|
||||
var responseParts []ResponsesContentPart
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_text",
|
||||
Text: p.Text,
|
||||
})
|
||||
}
|
||||
case "image_url":
|
||||
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_image",
|
||||
ImageURL: p.ImageURL.URL,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content, err := json.Marshal(responseParts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
// chatAssistantToResponses converts an assistant message. If there is both
|
||||
// text content and tool_calls, the text is emitted as an assistant message
|
||||
// first, then each tool_call becomes a function_call item. If the content is
|
||||
// empty/nil and there are tool_calls, only function_call items are emitted.
|
||||
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var items []ResponsesInputItem
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
}
|
||||
|
||||
// Emit one function_call item per tool_call.
|
||||
for _, tc := range m.ToolCalls {
|
||||
args := tc.Function.Arguments
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
items = append(items, ResponsesInputItem{
|
||||
Type: "function_call",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
ID: tc.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
output, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if output == "" {
|
||||
output = "(empty)"
|
||||
}
|
||||
return []ResponsesInputItem{{
|
||||
Type: "function_call_output",
|
||||
CallID: m.ToolCallID,
|
||||
Output: output,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// chatFunctionToResponses converts a legacy function result message
|
||||
// (role=function) into a function_call_output item. The Name field is used as
|
||||
// call_id since legacy function calls do not carry a separate call_id.
|
||||
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
output, err := parseChatContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if output == "" {
|
||||
output = "(empty)"
|
||||
}
|
||||
return []ResponsesInputItem{{
|
||||
Type: "function_call_output",
|
||||
CallID: m.Name,
|
||||
Output: output,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// parseChatContent returns the string value of a ChatMessage Content field.
|
||||
// Content must be a JSON string. Returns "" if content is null or empty.
|
||||
func parseChatContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err != nil {
|
||||
return "", fmt.Errorf("parse content as string: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
|
||||
// function definitions to Responses API tool definitions.
|
||||
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
|
||||
var out []ResponsesTool
|
||||
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" || t.Function == nil {
|
||||
continue
|
||||
}
|
||||
rt := ResponsesTool{
|
||||
Type: "function",
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: t.Function.Parameters,
|
||||
Strict: t.Function.Strict,
|
||||
}
|
||||
out = append(out, rt)
|
||||
}
|
||||
|
||||
// Legacy functions[] are treated as function-type tools.
|
||||
for _, f := range functions {
|
||||
rt := ResponsesTool{
|
||||
Type: "function",
|
||||
Name: f.Name,
|
||||
Description: f.Description,
|
||||
Parameters: f.Parameters,
|
||||
Strict: f.Strict,
|
||||
}
|
||||
out = append(out, rt)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
|
||||
// Responses API tool_choice value.
|
||||
//
|
||||
// "auto" → "auto"
|
||||
// "none" → "none"
|
||||
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
|
||||
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
|
||||
// Try string first ("auto", "none", etc.) — pass through as-is.
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Object form: {"name":"X"}
|
||||
var obj struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{"name": obj.Name},
|
||||
})
|
||||
}
|
||||
368
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
368
backend/internal/pkg/apicompat/responses_to_chatcompletions.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesToChatCompletions converts a Responses API response into a Chat
|
||||
// Completions response. Text output items are concatenated into
|
||||
// choices[0].message.content; function_call items become tool_calls.
|
||||
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
|
||||
id := resp.ID
|
||||
if id == "" {
|
||||
id = generateChatCmplID()
|
||||
}
|
||||
|
||||
out := &ChatCompletionsResponse{
|
||||
ID: id,
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: model,
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
for _, part := range item.Content {
|
||||
if part.Type == "output_text" && part.Text != "" {
|
||||
contentText += part.Text
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
toolCalls = append(toolCalls, ChatToolCall{
|
||||
ID: item.CallID,
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: item.Name,
|
||||
Arguments: item.Arguments,
|
||||
},
|
||||
})
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
contentText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
// silently consumed — results already incorporated into text output
|
||||
}
|
||||
}
|
||||
|
||||
msg := ChatMessage{Role: "assistant"}
|
||||
if len(toolCalls) > 0 {
|
||||
msg.ToolCalls = toolCalls
|
||||
}
|
||||
if contentText != "" {
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
out.Choices = []ChatChoice{{
|
||||
Index: 0,
|
||||
Message: msg,
|
||||
FinishReason: finishReason,
|
||||
}}
|
||||
|
||||
if resp.Usage != nil {
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: resp.Usage.InputTokens,
|
||||
CompletionTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
out.Usage = usage
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
|
||||
switch status {
|
||||
case "incomplete":
|
||||
if details != nil && details.Reason == "max_output_tokens" {
|
||||
return "length"
|
||||
}
|
||||
return "stop"
|
||||
case "completed":
|
||||
if len(toolCalls) > 0 {
|
||||
return "tool_calls"
|
||||
}
|
||||
return "stop"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ResponsesEventToChatState tracks state for converting a sequence of Responses
|
||||
// SSE events into Chat Completions SSE chunks.
|
||||
type ResponsesEventToChatState struct {
|
||||
ID string
|
||||
Model string
|
||||
Created int64
|
||||
SentRole bool
|
||||
SawToolCall bool
|
||||
SawText bool
|
||||
Finalized bool // true after finish chunk has been emitted
|
||||
NextToolCallIndex int // next sequential tool_call index to assign
|
||||
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
|
||||
IncludeUsage bool
|
||||
Usage *ChatUsage
|
||||
}
|
||||
|
||||
// NewResponsesEventToChatState returns an initialised stream state.
|
||||
func NewResponsesEventToChatState() *ResponsesEventToChatState {
|
||||
return &ResponsesEventToChatState{
|
||||
ID: generateChatCmplID(),
|
||||
Created: time.Now().Unix(),
|
||||
OutputIndexToToolIndex: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
|
||||
// or more Chat Completions chunks, updating state as it goes.
|
||||
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
switch evt.Type {
|
||||
case "response.created":
|
||||
return resToChatHandleCreated(evt, state)
|
||||
case "response.output_text.delta":
|
||||
return resToChatHandleTextDelta(evt, state)
|
||||
case "response.output_item.added":
|
||||
return resToChatHandleOutputItemAdded(evt, state)
|
||||
case "response.function_call_arguments.delta":
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
|
||||
// stream ended without a proper completion event (e.g. upstream disconnect).
|
||||
// It is idempotent: if a completion event already emitted the finish chunk,
|
||||
// this returns nil.
|
||||
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if state.Finalized {
|
||||
return nil
|
||||
}
|
||||
state.Finalized = true
|
||||
|
||||
finishReason := "stop"
|
||||
if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
|
||||
|
||||
if state.IncludeUsage && state.Usage != nil {
|
||||
chunks = append(chunks, ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{},
|
||||
Usage: state.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
|
||||
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("data: %s\n\n", data), nil
|
||||
}
|
||||
|
||||
// --- internal handlers ---
|
||||
|
||||
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Response != nil {
|
||||
if evt.Response.ID != "" {
|
||||
state.ID = evt.Response.ID
|
||||
}
|
||||
if state.Model == "" && evt.Response.Model != "" {
|
||||
state.Model = evt.Response.Model
|
||||
}
|
||||
}
|
||||
// Emit the role chunk.
|
||||
if state.SentRole {
|
||||
return nil
|
||||
}
|
||||
state.SentRole = true
|
||||
|
||||
role := "assistant"
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
|
||||
}
|
||||
|
||||
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
state.SawText = true
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
}
|
||||
|
||||
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Item == nil || evt.Item.Type != "function_call" {
|
||||
return nil
|
||||
}
|
||||
|
||||
state.SawToolCall = true
|
||||
idx := state.NextToolCallIndex
|
||||
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
|
||||
state.NextToolCallIndex++
|
||||
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||
ToolCalls: []ChatToolCall{{
|
||||
Index: &idx,
|
||||
ID: evt.Item.CallID,
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: evt.Item.Name,
|
||||
},
|
||||
}},
|
||||
})}
|
||||
}
|
||||
|
||||
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
|
||||
ToolCalls: []ChatToolCall{{
|
||||
Index: &idx,
|
||||
Function: ChatFunctionCall{
|
||||
Arguments: evt.Delta,
|
||||
},
|
||||
}},
|
||||
})}
|
||||
}
|
||||
|
||||
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
state.Finalized = true
|
||||
finishReason := "stop"
|
||||
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
u := evt.Response.Usage
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: u.InputTokens,
|
||||
CompletionTokens: u.OutputTokens,
|
||||
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||
}
|
||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
state.Usage = usage
|
||||
}
|
||||
|
||||
switch evt.Response.Status {
|
||||
case "incomplete":
|
||||
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
|
||||
finishReason = "length"
|
||||
}
|
||||
case "completed":
|
||||
if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
}
|
||||
} else if state.SawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
var chunks []ChatCompletionsChunk
|
||||
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
|
||||
|
||||
if state.IncludeUsage && state.Usage != nil {
|
||||
chunks = append(chunks, ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{},
|
||||
Usage: state.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: delta,
|
||||
FinishReason: nil,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
|
||||
empty := ""
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: state.Created,
|
||||
Model: state.Model,
|
||||
Choices: []ChatChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: ChatDelta{Content: &empty},
|
||||
FinishReason: &finishReason,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
|
||||
func generateChatCmplID() string {
|
||||
b := make([]byte, 12)
|
||||
_, _ = rand.Read(b)
|
||||
return "chatcmpl-" + hex.EncodeToString(b)
|
||||
}
|
||||
@@ -329,6 +329,148 @@ type ResponsesStreamEvent struct {
|
||||
SequenceNumber int `json:"sequence_number,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI Chat Completions API types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
|
||||
type ChatCompletionsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
|
||||
Tools []ChatTool `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
|
||||
|
||||
// Legacy function calling (deprecated but still supported)
|
||||
Functions []ChatFunction `json:"functions,omitempty"`
|
||||
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatStreamOptions configures streaming behavior.
|
||||
type ChatStreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ChatContentPart is a typed content part in a multi-modal message.
|
||||
type ChatContentPart struct {
|
||||
Type string `json:"type"` // "text" | "image_url"
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *ChatImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
// ChatImageURL contains the URL for an image content part.
|
||||
type ChatImageURL struct {
|
||||
URL string `json:"url"`
|
||||
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
|
||||
}
|
||||
|
||||
// ChatTool describes a tool available to the model.
|
||||
type ChatTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Function *ChatFunction `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
// ChatFunction describes a function tool definition.
|
||||
type ChatFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// ChatToolCall represents a tool call made by the assistant.
|
||||
// Index is only populated in streaming chunks (omitted in non-streaming responses).
|
||||
type ChatToolCall struct {
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"` // "function"
|
||||
Function ChatFunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// ChatFunctionCall contains the function name and arguments.
|
||||
type ChatFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
|
||||
type ChatCompletionsResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "chat.completion"
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChoice `json:"choices"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ChatChoice is a single completion choice.
|
||||
type ChatChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message ChatMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
|
||||
}
|
||||
|
||||
// ChatUsage holds token counts in Chat Completions format.
|
||||
type ChatUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ChatTokenDetails provides a breakdown of token usage.
|
||||
type ChatTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
||||
type ChatCompletionsChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"` // "chat.completion.chunk"
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatChunkChoice `json:"choices"`
|
||||
Usage *ChatUsage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// ChatChunkChoice is a single choice in a streaming chunk.
|
||||
type ChatChunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta ChatDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"` // pointer: null when not final
|
||||
}
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaFastMode}
|
||||
var DroppedBetas = []string{}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
@@ -18,10 +18,12 @@ func DefaultModels() []Model {
|
||||
return []Model{
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
28
backend/internal/pkg/gemini/models_test.go
Normal file
28
backend/internal/pkg/gemini/models_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package gemini
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
byName := make(map[string]Model, len(models))
|
||||
for _, model := range models {
|
||||
byName[model.Name] = model
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"models/gemini-2.5-flash-image",
|
||||
"models/gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for _, name := range required {
|
||||
model, ok := byName[name]
|
||||
if !ok {
|
||||
t.Fatalf("expected fallback model %q to exist", name)
|
||||
}
|
||||
if len(model.SupportedGenerationMethods) == 0 {
|
||||
t.Fatalf("expected fallback model %q to advertise generation methods", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,10 +13,12 @@ type Model struct {
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
|
||||
23
backend/internal/pkg/geminicli/models_test.go
Normal file
23
backend/internal/pkg/geminicli/models_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package geminicli
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
byID := make(map[string]Model, len(DefaultModels))
|
||||
for _, model := range DefaultModels {
|
||||
byID[model.ID] = model
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for _, id := range required {
|
||||
if _, ok := byID[id]; !ok {
|
||||
t.Fatalf("expected curated Gemini model %q to exist", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,6 +268,7 @@ type IDTokenClaims struct {
|
||||
type OpenAIAuthClaims struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id"`
|
||||
ChatGPTPlanType string `json:"chatgpt_plan_type"`
|
||||
UserID string `json:"user_id"`
|
||||
Organizations []OrganizationClaim `json:"organizations"`
|
||||
}
|
||||
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
// DecodeIDToken decodes the ID Token JWT payload without validating expiration.
|
||||
// Use this for best-effort extraction (e.g., during data import) where the token may be expired.
|
||||
func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
claims, err := DecodeIDToken(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||
const clockSkewTolerance = 120 // 秒
|
||||
now := time.Now().Unix()
|
||||
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// UserInfo represents user information extracted from ID Token claims.
|
||||
@@ -375,6 +387,7 @@ type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
ChatGPTUserID string
|
||||
PlanType string
|
||||
UserID string
|
||||
OrganizationID string
|
||||
Organizations []OrganizationClaim
|
||||
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
|
||||
if c.OpenAIAuth != nil {
|
||||
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
|
||||
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
|
||||
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
|
||||
info.UserID = c.OpenAIAuth.UserID
|
||||
info.Organizations = c.OpenAIAuth.Organizations
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -50,6 +51,18 @@ type accountRepository struct {
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
var schedulerNeutralExtraKeyPrefixes = []string{
|
||||
"codex_primary_",
|
||||
"codex_secondary_",
|
||||
"codex_5h_",
|
||||
"codex_7d_",
|
||||
}
|
||||
|
||||
var schedulerNeutralExtraKeys = map[string]struct{}{
|
||||
"codex_usage_updated_at": {},
|
||||
"session_window_utilization": {},
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
@@ -1185,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
} else {
|
||||
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
|
||||
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
|
||||
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
|
||||
if len(updates) == 0 {
|
||||
return false
|
||||
}
|
||||
for key := range updates {
|
||||
if isSchedulerNeutralExtraKey(key) {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isSchedulerNeutralExtraKey(key string) bool {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
if _, ok := schedulerNeutralExtraKeys[key]; ok {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
|
||||
@@ -23,6 +23,7 @@ type AccountRepoSuite struct {
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
if s.accounts == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.accounts[accountID], nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
if s.accounts == nil {
|
||||
s.accounts = make(map[int64]*service.Account)
|
||||
}
|
||||
if account != nil {
|
||||
s.accounts[account.ID] = account
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -623,6 +633,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
s.Require().Equal("val", got.Extra["key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-neutral",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Extra: map[string]any{"codex_usage_updated_at": "old"},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{
|
||||
accounts: map[int64]*service.Account{
|
||||
account.ID: {
|
||||
ID: account.ID,
|
||||
Platform: account.Platform,
|
||||
Status: service.StatusDisabled,
|
||||
Extra: map[string]any{
|
||||
"codex_usage_updated_at": "old",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
updates := map[string]any{
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
"codex_5h_used_percent": 88.5,
|
||||
"session_window_utilization": 0.42,
|
||||
}
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
|
||||
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
|
||||
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
|
||||
|
||||
var outboxCount int
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
|
||||
s.Require().Zero(outboxCount)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().NotNil(cacheRecorder.accounts[account.ID])
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-codex-exhausted",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
|
||||
"codex_7d_reset_after_seconds": 86400,
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(0, count)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-mixed",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1, count)
|
||||
}
|
||||
|
||||
// --- GetByCRSAccountID ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
|
||||
@@ -164,6 +164,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldModelRoutingEnabled,
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSimulateClaudeMaxEnabled,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
@@ -452,6 +453,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
|
||||
// as quota_exhausted, and returns the latest quota state in one round trip.
|
||||
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
|
||||
query := `
|
||||
UPDATE api_keys
|
||||
SET
|
||||
quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL
|
||||
RETURNING quota_used, quota, key, status
|
||||
`
|
||||
|
||||
state := &service.APIKeyQuotaUsageState{}
|
||||
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
@@ -619,6 +646,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
|
||||
@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
|
||||
user := s.mustCreateUser("quota-state@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
|
||||
key.Quota = 3
|
||||
key.QuotaUsed = 1
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
|
||||
|
||||
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
|
||||
s.Require().NotNil(state)
|
||||
s.Require().Equal(3.5, state.QuotaUsed)
|
||||
s.Require().Equal(3.0, state.Quota)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
|
||||
s.Require().Equal(key.Key, state.Key)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(3.5, got.QuotaUsed)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
|
||||
@@ -147,17 +147,47 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
`)
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// startupCleanupScript 清理非当前进程前缀的槽位成员。
|
||||
// KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。
|
||||
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。
|
||||
startupCleanupScript = redis.NewScript(`
|
||||
local activePrefix = ARGV[1]
|
||||
local slotTTL = tonumber(ARGV[2])
|
||||
local removed = 0
|
||||
for i = 1, #KEYS do
|
||||
local key = KEYS[i]
|
||||
local members = redis.call('ZRANGE', key, 0, -1)
|
||||
for _, member in ipairs(members) do
|
||||
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
|
||||
removed = removed + redis.call('ZREM', key, member)
|
||||
end
|
||||
end
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, slotTTL)
|
||||
end
|
||||
end
|
||||
return removed
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
if activeRequestPrefix == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 清理有序集合中非当前进程前缀的成员
|
||||
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
|
||||
for _, pattern := range slotPatterns {
|
||||
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 删除所有等待队列计数器(重启后计数器失效)
|
||||
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
|
||||
for _, pattern := range waitPatterns {
|
||||
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
|
||||
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
|
||||
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
return fmt.Errorf("del %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
||||
accountID := int64(301)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
||||
now := time.Now().Unix()
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-1"},
|
||||
redis.Z{Score: float64(now), Member: "keep-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-2"},
|
||||
redis.Z{Score: float64(now), Member: "keep-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-1"},
|
||||
redis.Z{Score: now, Member: "activeproc-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-2"},
|
||||
redis.Z{Score: now, Member: "activeproc-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
|
||||
accountID := int64(903)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.EqualValues(s.T(), 0, exists)
|
||||
}
|
||||
|
||||
@@ -61,7 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -129,7 +130,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
|
||||
@@ -16,19 +16,7 @@ type opsRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||
return &opsRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return 0, fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
q := `
|
||||
const insertOpsErrorLogSQL = `
|
||||
INSERT INTO ops_error_logs (
|
||||
request_id,
|
||||
client_request_id,
|
||||
@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs (
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
||||
) RETURNING id`
|
||||
)`
|
||||
|
||||
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||
return &opsRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return 0, fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
var id int64
|
||||
err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
insertOpsErrorLogSQL+" RETURNING id",
|
||||
opsInsertErrorLogArgs(input)...,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if len(inputs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = stmt.Close()
|
||||
}()
|
||||
|
||||
var inserted int64
|
||||
for _, input := range inputs {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil {
|
||||
return inserted, err
|
||||
}
|
||||
inserted++
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return inserted, err
|
||||
}
|
||||
return inserted, nil
|
||||
}
|
||||
|
||||
func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
|
||||
return []any{
|
||||
opsNullString(input.RequestID),
|
||||
opsNullString(input.ClientRequestID),
|
||||
opsNullInt64(input.UserID),
|
||||
@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs (
|
||||
input.IsRetryable,
|
||||
input.RetryCount,
|
||||
input.CreatedAt,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
|
||||
|
||||
repo := NewOpsRepository(integrationDB).(*opsRepository)
|
||||
now := time.Now().UTC()
|
||||
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
|
||||
{
|
||||
RequestID: "batch-ops-1",
|
||||
ErrorPhase: "upstream",
|
||||
ErrorType: "upstream_error",
|
||||
Severity: "error",
|
||||
StatusCode: 429,
|
||||
ErrorMessage: "rate limited",
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
RequestID: "batch-ops-2",
|
||||
ErrorPhase: "internal",
|
||||
ErrorType: "api_error",
|
||||
Severity: "error",
|
||||
StatusCode: 500,
|
||||
ErrorMessage: "internal error",
|
||||
CreatedAt: now.Add(time.Millisecond),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, inserted)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(12345)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(67890)
|
||||
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
|
||||
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
const schedulerOutboxDedupWindow = time.Second
|
||||
|
||||
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
||||
return &schedulerOutboxRepository{db: db}
|
||||
}
|
||||
@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str
|
||||
}
|
||||
payloadArg = encoded
|
||||
}
|
||||
_, err := exec.ExecContext(ctx, `
|
||||
query := `
|
||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`, eventType, accountID, groupID, payloadArg)
|
||||
`
|
||||
args := []any{eventType, accountID, groupID, payloadArg}
|
||||
if schedulerOutboxEventSupportsDedup(eventType) {
|
||||
query = `
|
||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||
SELECT $1, $2, $3, $4
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM scheduler_outbox
|
||||
WHERE event_type = $1
|
||||
AND account_id IS NOT DISTINCT FROM $2
|
||||
AND group_id IS NOT DISTINCT FROM $3
|
||||
AND created_at >= NOW() - make_interval(secs => $5)
|
||||
)
|
||||
`
|
||||
args = append(args, schedulerOutboxDedupWindow.Seconds())
|
||||
}
|
||||
_, err := exec.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func schedulerOutboxEventSupportsDedup(eventType string) bool {
|
||||
switch eventType {
|
||||
case service.SchedulerOutboxEventAccountChanged,
|
||||
service.SchedulerOutboxEventGroupChanged,
|
||||
service.SchedulerOutboxEventFullRebuild:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1873,7 +1873,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
||||
query := `
|
||||
SELECT
|
||||
COALESCE(ul.group_id, 0) as group_id,
|
||||
COALESCE(g.name, '') as group_name,
|
||||
COALESCE(g.name, '(无分组)') as group_name,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||
|
||||
@@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||
query := `
|
||||
SELECT ugr.user_id, u.email, ugr.rate_multiplier
|
||||
FROM user_group_rate_multipliers ugr
|
||||
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
||||
WHERE ugr.group_id = $1
|
||||
ORDER BY ugr.user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var result []service.UserGroupRateEntry
|
||||
for rows.Next() {
|
||||
var entry service.UserGroupRateEntry
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserEmail, &entry.RateMultiplier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
|
||||
@@ -645,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
|
||||
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
|
||||
@@ -228,6 +228,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
}
|
||||
@@ -264,6 +265,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError)
|
||||
accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh)
|
||||
|
||||
// Antigravity 默认模型映射
|
||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||
@@ -396,6 +399,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 请求整流器配置
|
||||
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
||||
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
||||
// Beta 策略配置
|
||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||
// Sora S3 存储配置
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
@@ -451,6 +457,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||
subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota)
|
||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||
}
|
||||
|
||||
|
||||
@@ -61,6 +61,12 @@ func RegisterAuthRoutes(
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
auth.POST("/oauth/linuxdo/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteLinuxDoOAuthRegistration,
|
||||
)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
|
||||
@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
|
||||
},
|
||||
})
|
||||
})
|
||||
// OpenAI Chat Completions API
|
||||
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
|
||||
}
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||
|
||||
@@ -45,16 +45,23 @@ const (
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
type TestEvent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"`
|
||||
MimeType string `json:"mime_type,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
defaultGeminiTextTestPrompt = "hi"
|
||||
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
|
||||
)
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
// All account types use full Claude Code client characteristics, only auth header differs
|
||||
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get account
|
||||
@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
if account.IsGemini() {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
return s.testGeminiAccountConnection(c, account, modelID, prompt)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.routeAntigravityTest(c, account, modelID)
|
||||
return s.routeAntigravityTest(c, account, modelID, prompt)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformSora {
|
||||
@@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// testGeminiAccountConnection tests a Gemini account's connection
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
@@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create test payload (Gemini format)
|
||||
payload := createGeminiTestPayload()
|
||||
payload := createGeminiTestPayload(testModelID, prompt)
|
||||
|
||||
// Build request based on account type
|
||||
var req *http.Request
|
||||
@@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string {
|
||||
|
||||
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
|
||||
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
|
||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
|
||||
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if strings.HasPrefix(modelID, "gemini-") {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
return s.testGeminiAccountConnection(c, account, modelID, prompt)
|
||||
}
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
@@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// createGeminiTestPayload creates a minimal test payload for Gemini API
|
||||
func createGeminiTestPayload() []byte {
|
||||
// createGeminiTestPayload creates a minimal test payload for Gemini API.
|
||||
// Image models use the image-generation path so the frontend can preview the returned image.
|
||||
func createGeminiTestPayload(modelID string, prompt string) []byte {
|
||||
if isImageGenerationModel(modelID) {
|
||||
imagePrompt := strings.TrimSpace(prompt)
|
||||
if imagePrompt == "" {
|
||||
imagePrompt = defaultGeminiImageTestPrompt
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"parts": []map[string]any{
|
||||
{"text": imagePrompt},
|
||||
},
|
||||
},
|
||||
},
|
||||
"generationConfig": map[string]any{
|
||||
"responseModalities": []string{"TEXT", "IMAGE"},
|
||||
"imageConfig": map[string]any{
|
||||
"aspectRatio": "1:1",
|
||||
},
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(payload)
|
||||
return bytes
|
||||
}
|
||||
|
||||
textPrompt := strings.TrimSpace(prompt)
|
||||
if textPrompt == "" {
|
||||
textPrompt = defaultGeminiTextTestPrompt
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"parts": []map[string]any{
|
||||
{"text": "hi"},
|
||||
{"text": textPrompt},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
|
||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
if inlineData, ok := partMap["inlineData"].(map[string]any); ok {
|
||||
mimeType, _ := inlineData["mimeType"].(string)
|
||||
data, _ := inlineData["data"].(string)
|
||||
if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" {
|
||||
s.sendEvent(c, TestEvent{
|
||||
Type: "image",
|
||||
ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data),
|
||||
MimeType: mimeType,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
|
||||
ginCtx, _ := gin.CreateTestContext(w)
|
||||
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
|
||||
|
||||
finishedAt := time.Now()
|
||||
body := w.Body.String()
|
||||
|
||||
59
backend/internal/service/account_test_service_gemini_test.go
Normal file
59
backend/internal/service/account_test_service_gemini_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
|
||||
|
||||
var parsed struct {
|
||||
Contents []struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"contents"`
|
||||
GenerationConfig struct {
|
||||
ResponseModalities []string `json:"responseModalities"`
|
||||
ImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio"`
|
||||
} `json:"imageConfig"`
|
||||
} `json:"generationConfig"`
|
||||
}
|
||||
|
||||
require.NoError(t, json.Unmarshal(payload, &parsed))
|
||||
require.Len(t, parsed.Contents, 1)
|
||||
require.Len(t, parsed.Contents[0].Parts, 1)
|
||||
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
|
||||
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
|
||||
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
|
||||
}
|
||||
|
||||
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
ctx, recorder := newSoraTestContext()
|
||||
svc := &AccountTestService{}
|
||||
|
||||
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
|
||||
|
||||
err := svc.processGeminiStream(ctx, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := recorder.Body.String()
|
||||
require.Contains(t, body, "\"type\":\"content\"")
|
||||
require.Contains(t, body, "\"text\":\"ok\"")
|
||||
require.Contains(t, body, "\"type\":\"image\"")
|
||||
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
|
||||
require.Contains(t, body, "\"mime_type\":\"image/png\"")
|
||||
}
|
||||
@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
|
||||
mergeAccountExtra(account, updates)
|
||||
if resetAt != nil {
|
||||
account.RateLimitResetAt = resetAt
|
||||
}
|
||||
if usage.UpdatedAt == nil {
|
||||
usage.UpdatedAt = &now
|
||||
}
|
||||
@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
|
||||
if account == nil || !account.IsOAuth() {
|
||||
return nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
return nil, nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
modelID := openaipkg.DefaultTestModel
|
||||
payload := createOpenAITestPayload(modelID, true)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create openai probe request: %w", err)
|
||||
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
|
||||
}
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build openai probe client: %w", err)
|
||||
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
updates, err := extractOpenAICodexProbeUpdates(resp)
|
||||
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if len(updates) > 0 || resetAt != nil {
|
||||
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
|
||||
return updates, resetAt, nil
|
||||
}
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
|
||||
if s == nil || s.accountRepo == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
if len(updates) == 0 && resetAt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if len(updates) > 0 {
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
if resetAt != nil {
|
||||
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
|
||||
if resp == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
baseTime := time.Now()
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
|
||||
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
|
||||
if len(updates) > 0 {
|
||||
return updates, resetAt, nil
|
||||
}
|
||||
return nil, resetAt, nil
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||
if resp == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) > 0 {
|
||||
return updates, nil
|
||||
}
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, nil
|
||||
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
|
||||
return updates, err
|
||||
}
|
||||
|
||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||
|
||||
@@ -1,11 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type accountUsageCodexProbeRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
updateExtraCh chan map[string]any
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
if r.updateExtraCh != nil {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtraCh <- copied
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||
if err != nil {
|
||||
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
t.Fatal("expected codex probe updates from 429 headers")
|
||||
}
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected resetAt from exhausted codex headers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repo := &accountUsageCodexProbeRepo{
|
||||
updateExtraCh: make(chan map[string]any, 1),
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &AccountUsageService{accountRepo: repo}
|
||||
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
||||
|
||||
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
|
||||
}, &resetAt)
|
||||
|
||||
select {
|
||||
case updates := <-repo.updateExtraCh:
|
||||
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("waiting for codex probe extra persistence timed out")
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-repo.rateLimitCh:
|
||||
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
|
||||
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ type AdminService interface {
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// API Key management (admin)
|
||||
@@ -138,9 +139,10 @@ type CreateGroupInput struct {
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
SimulateClaudeMaxEnabled *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string
|
||||
// Sora 存储配额
|
||||
@@ -177,9 +179,10 @@ type UpdateGroupInput struct {
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
SimulateClaudeMaxEnabled *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string
|
||||
// Sora 存储配额
|
||||
@@ -363,6 +366,10 @@ type ProxyExitInfoProber interface {
|
||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
type proxyQualityTarget struct {
|
||||
Target string
|
||||
URL string
|
||||
@@ -432,16 +439,13 @@ type adminServiceImpl struct {
|
||||
entClient *dbent.Client // 用于开启数据库事务
|
||||
settingService *SettingService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
userSubRepo UserSubscriptionRepository
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(
|
||||
userRepo UserRepository,
|
||||
@@ -459,6 +463,7 @@ func NewAdminService(
|
||||
entClient *dbent.Client,
|
||||
settingService *SettingService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
@@ -476,6 +481,7 @@ func NewAdminService(
|
||||
entClient: entClient,
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
userSubRepo: userSubRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -857,6 +863,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
if input.MCPXMLInject != nil {
|
||||
mcpXMLInject = *input.MCPXMLInject
|
||||
}
|
||||
simulateClaudeMaxEnabled := false
|
||||
if input.SimulateClaudeMaxEnabled != nil {
|
||||
if platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
|
||||
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
|
||||
}
|
||||
simulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||
var accountIDsToCopy []int64
|
||||
@@ -913,6 +926,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||
ModelRouting: input.ModelRouting,
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
||||
@@ -1124,6 +1138,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.MCPXMLInject != nil {
|
||||
group.MCPXMLInject = *input.MCPXMLInject
|
||||
}
|
||||
if input.SimulateClaudeMaxEnabled != nil {
|
||||
if group.Platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
|
||||
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
|
||||
}
|
||||
group.SimulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
|
||||
}
|
||||
if group.Platform != PlatformAnthropic {
|
||||
group.SimulateClaudeMaxEnabled = false
|
||||
}
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
if input.SupportedModelScopes != nil {
|
||||
@@ -1241,6 +1264,13 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||
}
|
||||
@@ -1277,9 +1307,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
if group.Status != StatusActive {
|
||||
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
|
||||
}
|
||||
// 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
|
||||
// 订阅类型分组:用户须持有该分组的有效订阅才可绑定
|
||||
if group.IsSubscriptionType() {
|
||||
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
|
||||
if s.userSubRepo == nil {
|
||||
return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured")
|
||||
}
|
||||
if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil {
|
||||
if errors.Is(err, ErrSubscriptionNotFound) {
|
||||
return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
gid := *groupID
|
||||
@@ -1287,7 +1325,7 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
|
||||
apiKey.Group = group
|
||||
|
||||
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
|
||||
if group.IsExclusive {
|
||||
if group.IsExclusive && !group.IsSubscriptionType() {
|
||||
opCtx := ctx
|
||||
var tx *dbent.Tx
|
||||
if s.entClient == nil {
|
||||
|
||||
@@ -32,28 +32,44 @@ func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context,
|
||||
return s.addGroupErr
|
||||
}
|
||||
|
||||
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
|
||||
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
|
||||
type apiKeyRepoStubForGroupUpdate struct {
|
||||
@@ -194,6 +210,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
type userSubRepoStubForGroupUpdate struct {
|
||||
userSubRepoNoop
|
||||
getActiveSub *UserSubscription
|
||||
getActiveErr error
|
||||
called bool
|
||||
calledUserID int64
|
||||
calledGroupID int64
|
||||
}
|
||||
|
||||
func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
|
||||
s.called = true
|
||||
s.calledUserID = userID
|
||||
s.calledGroupID = groupID
|
||||
if s.getActiveErr != nil {
|
||||
return nil, s.getActiveErr
|
||||
}
|
||||
if s.getActiveSub == nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *s.getActiveSub
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -386,14 +425,49 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
|
||||
|
||||
// 无有效订阅时应拒绝绑定
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
|
||||
require.True(t, userSubRepo.called)
|
||||
require.Equal(t, int64(42), userSubRepo.calledUserID)
|
||||
require.Equal(t, int64(10), userSubRepo.calledGroupID)
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
|
||||
|
||||
// 订阅类型分组应被阻止绑定
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
|
||||
require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
userSubRepo := &userSubRepoStubForGroupUpdate{
|
||||
getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
|
||||
}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.True(t, userSubRepo.called)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,16 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
||||
s.getByIDsCalled = true
|
||||
s.getByIDsIDs = append([]int64{}, ids...)
|
||||
@@ -63,16 +73,6 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
|
||||
@@ -785,3 +785,57 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
enabled := true
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "openai-group",
|
||||
Platform: PlatformOpenAI,
|
||||
SimulateClaudeMaxEnabled: &enabled,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "openai-group",
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
enabled := true
|
||||
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
SimulateClaudeMaxEnabled: &enabled,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "anthropic-group",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
SimulateClaudeMaxEnabled: true,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.False(t, repo.updated.SimulateClaudeMaxEnabled)
|
||||
}
|
||||
|
||||
@@ -68,6 +68,10 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
panic("unexpected GetByGroupID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
panic("unexpected DeleteByGroupID call")
|
||||
}
|
||||
|
||||
@@ -1673,7 +1673,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
var clientDisconnect bool
|
||||
if claudeReq.Stream {
|
||||
// 客户端要求流式,直接透传转换
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
@@ -1683,7 +1683,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
clientDisconnect = streamRes.clientDisconnect
|
||||
} else {
|
||||
// 客户端要求非流式,收集流式响应后转换返回
|
||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
@@ -1692,6 +1692,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
}
|
||||
|
||||
// Claude Max cache billing: 同步 ForwardResult.Usage 与客户端响应体一致
|
||||
applyClaudeMaxCacheBillingPolicyToUsage(usage, parsedRequestFromGinContext(c), claudeMaxGroupFromGinContext(c), originalModel, account.ID)
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
@@ -2164,6 +2167,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini 原生请求中的 thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回
|
||||
// "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。
|
||||
signatureCheckBody := respBody
|
||||
if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 {
|
||||
signatureCheckBody = unwrapped
|
||||
}
|
||||
if resp.StatusCode == http.StatusBadRequest &&
|
||||
s.settingService != nil &&
|
||||
s.settingService.IsSignatureRectifierEnabled(ctx) &&
|
||||
isSignatureRelatedError(signatureCheckBody) &&
|
||||
bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) {
|
||||
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody)))
|
||||
upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "signature_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
|
||||
|
||||
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
|
||||
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
|
||||
if wrapErr == nil {
|
||||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: prefix,
|
||||
account: account,
|
||||
proxyURL: proxyURL,
|
||||
accessToken: accessToken,
|
||||
action: upstreamAction,
|
||||
body: retryWrappedBody,
|
||||
c: c,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
accountRepo: s.accountRepo,
|
||||
handleError: s.handleUpstreamError,
|
||||
requestedModel: originalModel,
|
||||
isStickySession: isStickySession,
|
||||
groupID: 0,
|
||||
sessionHash: "",
|
||||
})
|
||||
if retryErr == nil {
|
||||
retryResp := retryResult.resp
|
||||
if retryResp.StatusCode < 400 {
|
||||
resp = retryResp
|
||||
} else {
|
||||
retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
retryOpsBody := retryRespBody
|
||||
if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 {
|
||||
retryOpsBody = retryUnwrapped
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: retryResp.StatusCode,
|
||||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||
Kind: "signature_retry",
|
||||
Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))),
|
||||
Detail: s.getUpstreamErrorDetail(retryOpsBody),
|
||||
})
|
||||
respBody = retryRespBody
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
Header: retryResp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
|
||||
}
|
||||
contentType = resp.Header.Get("Content-Type")
|
||||
}
|
||||
} else {
|
||||
if switchErr, ok := IsAntigravityAccountSwitchError(retryErr); ok {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: http.StatusServiceUnavailable,
|
||||
Kind: "failover",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "signature_retry_request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||
})
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr)
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap failed: %v", account.ID, wrapErr)
|
||||
}
|
||||
}
|
||||
|
||||
// fallback 成功:继续按正常响应处理
|
||||
if resp.StatusCode < 400 {
|
||||
goto handleSuccess
|
||||
@@ -3489,7 +3598,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
|
||||
|
||||
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
||||
// 用于处理客户端非流式请求但上游只支持流式的情况
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
@@ -3647,6 +3756,9 @@ returnResponse:
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||
}
|
||||
|
||||
// Claude Max cache billing simulation (non-streaming)
|
||||
claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID)
|
||||
|
||||
c.Data(http.StatusOK, "application/json", claudeResp)
|
||||
|
||||
// 转换为 service.ClaudeUsage
|
||||
@@ -3661,7 +3773,7 @@ returnResponse:
|
||||
}
|
||||
|
||||
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
@@ -3674,6 +3786,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
|
||||
processor := antigravity.NewStreamingProcessor(originalModel)
|
||||
setupClaudeMaxStreamingHook(c, processor, originalModel, accountID)
|
||||
|
||||
var firstTokenMs *int
|
||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
@@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
type queuedHTTPUpstreamStub struct {
|
||||
responses []*http.Response
|
||||
errors []error
|
||||
requestBodies [][]byte
|
||||
callCount int
|
||||
onCall func(*http.Request, *queuedHTTPUpstreamStub)
|
||||
}
|
||||
|
||||
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
if req != nil && req.Body != nil {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
s.requestBodies = append(s.requestBodies, body)
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
} else {
|
||||
s.requestBodies = append(s.requestBodies, nil)
|
||||
}
|
||||
|
||||
idx := s.callCount
|
||||
s.callCount++
|
||||
if s.onCall != nil {
|
||||
s.onCall(req, s)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
if idx < len(s.responses) {
|
||||
resp = s.responses[idx]
|
||||
}
|
||||
var err error
|
||||
if idx < len(s.errors) {
|
||||
err = s.errors[idx]
|
||||
}
|
||||
if resp == nil && err == nil {
|
||||
return nil, errors.New("unexpected upstream call")
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
|
||||
return s.Do(req, proxyURL, accountID, concurrency)
|
||||
}
|
||||
|
||||
type antigravitySettingRepoStub struct{}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
@@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
|
||||
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
|
||||
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req-sig-1"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req-sig-2"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
const originalModel = "gemini-3.1-pro-preview"
|
||||
const mappedModel = "gemini-3.1-pro-high"
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "acc-gemini-signature",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
originalModel: mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
|
||||
|
||||
firstReq := string(upstream.requestBodies[0])
|
||||
secondReq := string(upstream.requestBodies[1])
|
||||
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
|
||||
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
|
||||
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
|
||||
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
|
||||
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
|
||||
|
||||
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, events)
|
||||
require.Equal(t, "signature_error", events[0].Kind)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
|
||||
|
||||
const originalModel = "gemini-3.1-pro-preview"
|
||||
const mappedModel = "gemini-3.1-pro-high"
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Name: "acc-gemini-signature-failover",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
originalModel: mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req-sig-failover-1"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
|
||||
},
|
||||
},
|
||||
onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) {
|
||||
if stub.callCount != 1 {
|
||||
return
|
||||
}
|
||||
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
|
||||
account.Extra = map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
mappedModel: map[string]any{
|
||||
"rate_limit_reset_at": futureResetAt,
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true)
|
||||
require.Nil(t, result)
|
||||
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400")
|
||||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||
require.True(t, failoverErr.ForceCacheBilling)
|
||||
require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request")
|
||||
|
||||
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 2)
|
||||
require.Equal(t, "signature_error", events[0].Kind)
|
||||
require.Equal(t, "failover", events[1].Kind)
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_UsageAndFirstToken
|
||||
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
||||
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||
@@ -710,7 +922,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -787,7 +999,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -990,7 +1202,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||
_ = pr.Close()
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -1022,7 +1234,7 @@ func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
|
||||
fmt.Fprintln(pw, "")
|
||||
}()
|
||||
|
||||
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||
_ = pr.Close()
|
||||
|
||||
// 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover
|
||||
@@ -1054,7 +1266,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
||||
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
@@ -59,9 +59,10 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
|
||||
@@ -244,6 +244,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: apiKey.Group.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||
@@ -303,6 +304,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||
SimulateClaudeMaxEnabled: snapshot.Group.SimulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
|
||||
return d.Usage7d
|
||||
}
|
||||
|
||||
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
|
||||
// It is intentionally small so repositories can return it from a single SQL statement.
|
||||
type APIKeyQuotaUsageState struct {
|
||||
QuotaUsed float64
|
||||
Quota float64
|
||||
Key string
|
||||
Status string
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
|
||||
return nil
|
||||
}
|
||||
|
||||
type quotaStateReader interface {
|
||||
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
|
||||
}
|
||||
|
||||
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
|
||||
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("increment quota used: %w", err)
|
||||
}
|
||||
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
|
||||
s.InvalidateAuthCacheByKey(ctx, state.Key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use repository to atomically increment quota_used
|
||||
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
||||
if err != nil {
|
||||
|
||||
170
backend/internal/service/api_key_service_quota_test.go
Normal file
170
backend/internal/service/api_key_service_quota_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type quotaStateRepoStub struct {
|
||||
quotaBaseAPIKeyRepoStub
|
||||
stateCalls int
|
||||
state *APIKeyQuotaUsageState
|
||||
stateErr error
|
||||
}
|
||||
|
||||
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
|
||||
s.stateCalls++
|
||||
if s.stateErr != nil {
|
||||
return nil, s.stateErr
|
||||
}
|
||||
if s.state == nil {
|
||||
return nil, nil
|
||||
}
|
||||
out := *s.state
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
type quotaStateCacheStub struct {
|
||||
deleteAuthKeys []string
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
|
||||
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type quotaBaseAPIKeyRepoStub struct {
|
||||
getByIDCalls int
|
||||
}
|
||||
|
||||
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
|
||||
s.getByIDCalls++
|
||||
return nil, nil
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
|
||||
panic("unexpected GetKeyAndOwnerID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
|
||||
panic("unexpected GetByKeyForAuth call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||
panic("unexpected VerifyOwnership call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected CountByUserID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected ClearGroupIDByGroupID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByUserID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
|
||||
panic("unexpected IncrementQuotaUsed call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
|
||||
panic("unexpected IncrementRateLimitUsage call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
|
||||
panic("unexpected ResetRateLimitWindows call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected GetRateLimitData call")
|
||||
}
|
||||
|
||||
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
|
||||
repo := "aStateRepoStub{
|
||||
state: &APIKeyQuotaUsageState{
|
||||
QuotaUsed: 12,
|
||||
Quota: 10,
|
||||
Key: "sk-test-quota",
|
||||
Status: StatusAPIKeyQuotaExhausted,
|
||||
},
|
||||
}
|
||||
cache := "aStateCacheStub{}
|
||||
svc := &APIKeyService{
|
||||
apiKeyRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.stateCalls)
|
||||
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
|
||||
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@@ -21,24 +22,25 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
|
||||
ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration")
|
||||
)
|
||||
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
@@ -58,6 +60,7 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
entClient *dbent.Client
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
@@ -76,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
entClient *dbent.Client,
|
||||
userRepo UserRepository,
|
||||
redeemRepo RedeemCodeRepository,
|
||||
refreshTokenCache RefreshTokenCache,
|
||||
@@ -88,6 +92,7 @@ func NewAuthService(
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
entClient: entClient,
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
@@ -523,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
@@ -552,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 检查是否需要邀请码
|
||||
var invitationRedeemCode *RedeemCode
|
||||
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
if invitationCode == "" {
|
||||
return nil, nil, ErrOAuthInvitationRequired
|
||||
}
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
@@ -579,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
if s.entClient != nil && invitationRedeemCode != nil {
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
if err := s.userRepo.Create(txCtx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@@ -618,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
|
||||
const pendingOAuthTokenTTL = 10 * time.Minute
|
||||
|
||||
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
|
||||
const pendingOAuthPurpose = "pending_oauth_registration"
|
||||
|
||||
type pendingOAuthClaims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Purpose string `json:"purpose"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
|
||||
// while waiting for the user to supply an invitation code.
|
||||
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: email,
|
||||
Username: username,
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
|
||||
// Returns ErrInvalidToken when the token is invalid or expired.
|
||||
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
|
||||
if len(tokenStr) > maxTokenLength {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
|
||||
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return []byte(s.cfg.JWT.Secret), nil
|
||||
})
|
||||
if parseErr != nil {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
claims, ok := token.Claims.(*pendingOAuthClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
if claims.Purpose != pendingOAuthPurpose {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
return claims.Email, claims.Username, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
|
||||
146
backend/internal/service/auth_service_pending_oauth_test.go
Normal file
146
backend/internal/service/auth_service_pending_oauth_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAuthServiceForPendingOAuthTest() *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret-pending-oauth",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
|
||||
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
email, username, err := svc.VerifyPendingOAuthToken(token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "user@example.com", email)
|
||||
require.Equal(t, "alice", username)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
// 签发一个普通 access token(JWTClaims,无 Purpose 字段)
|
||||
accessToken, err := svc.GenerateToken(&User{
|
||||
ID: 1,
|
||||
Email: "user@example.com",
|
||||
Role: RoleUser,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: "some_other_purpose",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: "", // 旧 token 无此字段,反序列化后为零值
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
past := time.Now().Add(-1 * time.Hour)
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(past),
|
||||
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
|
||||
other := NewAuthService(nil, nil, nil, nil, &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "other-secret"},
|
||||
}, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
_, _, err = svc.VerifyPendingOAuthToken(token)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
giant := make([]byte, maxTokenLength+1)
|
||||
for i := range giant {
|
||||
giant[i] = 'a'
|
||||
}
|
||||
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
@@ -130,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
}
|
||||
|
||||
return NewAuthService(
|
||||
nil, // entClient
|
||||
repo,
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
|
||||
@@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
|
||||
turnstileService := NewTurnstileService(settingService, verifier)
|
||||
|
||||
return NewAuthService(
|
||||
nil, // entClient
|
||||
&userRepoStub{},
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
|
||||
450
backend/internal/service/claude_max_cache_billing_policy.go
Normal file
450
backend/internal/service/claude_max_cache_billing_policy.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type claudeMaxCacheBillingOutcome struct {
|
||||
Simulated bool
|
||||
}
|
||||
|
||||
func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
||||
var out claudeMaxCacheBillingOutcome
|
||||
if usage == nil || !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) {
|
||||
return out
|
||||
}
|
||||
|
||||
resolvedModel := strings.TrimSpace(model)
|
||||
if resolvedModel == "" && parsed != nil {
|
||||
resolvedModel = strings.TrimSpace(parsed.Model)
|
||||
}
|
||||
|
||||
if hasCacheCreationTokens(*usage) {
|
||||
// Upstream already returned cache creation usage; keep original usage.
|
||||
return out
|
||||
}
|
||||
|
||||
if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) {
|
||||
return out
|
||||
}
|
||||
beforeInputTokens := usage.InputTokens
|
||||
out.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed)
|
||||
if out.Simulated {
|
||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
|
||||
resolvedModel,
|
||||
accountID,
|
||||
beforeInputTokens,
|
||||
usage.InputTokens,
|
||||
usage.CacheCreation1hTokens,
|
||||
)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isClaudeFamilyModel(model string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(normalized, "claude-")
|
||||
}
|
||||
|
||||
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
|
||||
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
|
||||
return false
|
||||
}
|
||||
return shouldApplyClaudeMaxBillingRulesForUsage(input.APIKey.Group, input.Result.Model, input.ParsedRequest)
|
||||
}
|
||||
|
||||
func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool {
|
||||
if group == nil {
|
||||
return false
|
||||
}
|
||||
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
|
||||
return false
|
||||
}
|
||||
|
||||
resolvedModel := model
|
||||
if resolvedModel == "" && parsed != nil {
|
||||
resolvedModel = parsed.Model
|
||||
}
|
||||
if !isClaudeFamilyModel(resolvedModel) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func hasCacheCreationTokens(usage ClaudeUsage) bool {
|
||||
return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0
|
||||
}
|
||||
|
||||
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
|
||||
if input == nil || input.Result == nil {
|
||||
return false
|
||||
}
|
||||
if !shouldApplyClaudeMaxBillingRules(input) {
|
||||
return false
|
||||
}
|
||||
return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest)
|
||||
}
|
||||
|
||||
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
|
||||
if usage.InputTokens <= 0 {
|
||||
return false
|
||||
}
|
||||
if hasCacheCreationTokens(usage) {
|
||||
return false
|
||||
}
|
||||
if !hasClaudeCacheSignals(parsed) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
|
||||
changed = false
|
||||
}
|
||||
}()
|
||||
return projectUsageToClaudeMax1H(usage, parsed)
|
||||
}
|
||||
|
||||
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
|
||||
if usage == nil {
|
||||
return false
|
||||
}
|
||||
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if totalWindowTokens <= 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed)
|
||||
if simulatedInputTokens <= 0 {
|
||||
simulatedInputTokens = 1
|
||||
}
|
||||
if simulatedInputTokens >= totalWindowTokens {
|
||||
simulatedInputTokens = totalWindowTokens - 1
|
||||
}
|
||||
|
||||
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
|
||||
if usage.InputTokens == simulatedInputTokens &&
|
||||
usage.CacheCreation5mTokens == 0 &&
|
||||
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
|
||||
usage.CacheCreationInputTokens == cacheCreation1hTokens {
|
||||
return false
|
||||
}
|
||||
|
||||
usage.InputTokens = simulatedInputTokens
|
||||
usage.CacheCreation5mTokens = 0
|
||||
usage.CacheCreation1hTokens = cacheCreation1hTokens
|
||||
usage.CacheCreationInputTokens = cacheCreation1hTokens
|
||||
return true
|
||||
}
|
||||
|
||||
type claudeCacheProjection struct {
|
||||
HasBreakpoint bool
|
||||
BreakpointCount int
|
||||
TotalEstimatedTokens int
|
||||
TailEstimatedTokens int
|
||||
}
|
||||
|
||||
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
|
||||
if totalWindowTokens <= 1 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
projection := analyzeClaudeCacheProjection(parsed)
|
||||
if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
totalEstimate := int64(projection.TotalEstimatedTokens)
|
||||
tailEstimate := int64(projection.TailEstimatedTokens)
|
||||
if tailEstimate > totalEstimate {
|
||||
tailEstimate = totalEstimate
|
||||
}
|
||||
|
||||
scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate
|
||||
if scaled <= 0 {
|
||||
scaled = 1
|
||||
}
|
||||
if scaled >= int64(totalWindowTokens) {
|
||||
scaled = int64(totalWindowTokens - 1)
|
||||
}
|
||||
return int(scaled)
|
||||
}
|
||||
|
||||
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
if hasTopLevelEphemeralCacheControl(parsed) {
|
||||
return true
|
||||
}
|
||||
return countExplicitCacheBreakpoints(parsed) > 0
|
||||
}
|
||||
|
||||
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
|
||||
if parsed == nil || len(parsed.Body) == 0 {
|
||||
return false
|
||||
}
|
||||
cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String())
|
||||
return strings.EqualFold(cacheType, "ephemeral")
|
||||
}
|
||||
|
||||
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
|
||||
var projection claudeCacheProjection
|
||||
if parsed == nil {
|
||||
return projection
|
||||
}
|
||||
|
||||
total := 0
|
||||
lastBreakpointAt := -1
|
||||
|
||||
switch system := parsed.System.(type) {
|
||||
case string:
|
||||
total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
|
||||
case []any:
|
||||
for _, raw := range system {
|
||||
block, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
total += estimateClaudeBlockTokens(block)
|
||||
if hasEphemeralCacheControl(block) {
|
||||
lastBreakpointAt = total
|
||||
projection.BreakpointCount++
|
||||
projection.HasBreakpoint = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rawMsg := range parsed.Messages {
|
||||
total += claudeMaxMessageOverheadTokens
|
||||
msg, ok := rawMsg.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
content, exists := msg["content"]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
|
||||
total += msgTokens
|
||||
if msgBreakCount > 0 {
|
||||
lastBreakpointAt = total - msgTokens + msgLastBreak
|
||||
projection.BreakpointCount += msgBreakCount
|
||||
projection.HasBreakpoint = true
|
||||
}
|
||||
}
|
||||
|
||||
if total <= 0 {
|
||||
total = 1
|
||||
}
|
||||
projection.TotalEstimatedTokens = total
|
||||
|
||||
if projection.HasBreakpoint && lastBreakpointAt >= 0 {
|
||||
tail := total - lastBreakpointAt
|
||||
if tail <= 0 {
|
||||
tail = 1
|
||||
}
|
||||
projection.TailEstimatedTokens = tail
|
||||
return projection
|
||||
}
|
||||
|
||||
if hasTopLevelEphemeralCacheControl(parsed) {
|
||||
tail := estimateLastUserMessageTokens(parsed)
|
||||
if tail <= 0 {
|
||||
tail = 1
|
||||
}
|
||||
projection.HasBreakpoint = true
|
||||
projection.BreakpointCount = 1
|
||||
projection.TailEstimatedTokens = tail
|
||||
}
|
||||
return projection
|
||||
}
|
||||
|
||||
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
|
||||
if parsed == nil {
|
||||
return 0
|
||||
}
|
||||
total := 0
|
||||
if system, ok := parsed.System.([]any); ok {
|
||||
for _, raw := range system {
|
||||
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, rawMsg := range parsed.Messages {
|
||||
msg, ok := rawMsg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msg["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, raw := range content {
|
||||
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func hasEphemeralCacheControl(block map[string]any) bool {
|
||||
if block == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := block["cache_control"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch cc := raw.(type) {
|
||||
case map[string]any:
|
||||
cacheType, _ := cc["type"].(string)
|
||||
return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
|
||||
case map[string]string:
|
||||
return strings.EqualFold(strings.TrimSpace(cc["type"]), "ephemeral")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
|
||||
switch value := content.(type) {
|
||||
case string:
|
||||
return estimateClaudeTextTokens(value), -1, 0
|
||||
case []any:
|
||||
total := 0
|
||||
lastBreak := -1
|
||||
breaks := 0
|
||||
for _, raw := range value {
|
||||
block, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
total += estimateClaudeBlockTokens(block)
|
||||
if hasEphemeralCacheControl(block) {
|
||||
lastBreak = total
|
||||
breaks++
|
||||
}
|
||||
}
|
||||
return total, lastBreak, breaks
|
||||
default:
|
||||
return estimateStructuredTokens(value), -1, 0
|
||||
}
|
||||
}
|
||||
|
||||
func estimateClaudeBlockTokens(block map[string]any) int {
|
||||
if block == nil {
|
||||
return claudeMaxUnknownContentTokens
|
||||
}
|
||||
tokens := claudeMaxBlockOverheadTokens
|
||||
blockType, _ := block["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(text)
|
||||
}
|
||||
case "tool_result":
|
||||
if content, ok := block["content"]; ok {
|
||||
nested, _, _ := estimateClaudeContentTokens(content)
|
||||
tokens += nested
|
||||
}
|
||||
case "tool_use":
|
||||
if name, ok := block["name"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(name)
|
||||
}
|
||||
if input, ok := block["input"]; ok {
|
||||
tokens += estimateStructuredTokens(input)
|
||||
}
|
||||
default:
|
||||
if text, ok := block["text"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(text)
|
||||
} else if content, ok := block["content"]; ok {
|
||||
nested, _, _ := estimateClaudeContentTokens(content)
|
||||
tokens += nested
|
||||
}
|
||||
}
|
||||
if tokens <= claudeMaxBlockOverheadTokens {
|
||||
tokens += claudeMaxUnknownContentTokens
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
|
||||
if parsed == nil || len(parsed.Messages) == 0 {
|
||||
return 0
|
||||
}
|
||||
for i := len(parsed.Messages) - 1; i >= 0; i-- {
|
||||
msg, ok := parsed.Messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := msg["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "user") {
|
||||
continue
|
||||
}
|
||||
tokens, _, _ := estimateClaudeContentTokens(msg["content"])
|
||||
return claudeMaxMessageOverheadTokens + tokens
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func estimateStructuredTokens(v any) int {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
raw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return claudeMaxUnknownContentTokens
|
||||
}
|
||||
return estimateClaudeTextTokens(string(raw))
|
||||
}
|
||||
|
||||
func estimateClaudeTextTokens(text string) int {
|
||||
if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok {
|
||||
return tokens
|
||||
}
|
||||
return estimateClaudeTextTokensHeuristic(text)
|
||||
}
|
||||
|
||||
func estimateClaudeTextTokensHeuristic(text string) int {
|
||||
normalized := strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
|
||||
if normalized == "" {
|
||||
return 0
|
||||
}
|
||||
asciiChars := 0
|
||||
nonASCIIChars := 0
|
||||
for _, r := range normalized {
|
||||
if r <= 127 {
|
||||
asciiChars++
|
||||
} else {
|
||||
nonASCIIChars++
|
||||
}
|
||||
}
|
||||
tokens := nonASCIIChars
|
||||
if asciiChars > 0 {
|
||||
tokens += (asciiChars + 3) / 4
|
||||
}
|
||||
if words := len(strings.Fields(normalized)); words > tokens {
|
||||
tokens = words
|
||||
}
|
||||
if tokens <= 0 {
|
||||
return 1
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
156
backend/internal/service/claude_max_simulation_test.go
Normal file
156
backend/internal/service/claude_max_simulation_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: 1200,
|
||||
CacheCreationInputTokens: 0,
|
||||
CacheCreation5mTokens: 0,
|
||||
CacheCreation1hTokens: 0,
|
||||
}
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": strings.Repeat("cached context ", 200),
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "summarize quickly",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changed := projectUsageToClaudeMax1H(usage, parsed)
|
||||
if !changed {
|
||||
t.Fatalf("expected usage to be projected")
|
||||
}
|
||||
|
||||
total := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if total != 1200 {
|
||||
t.Fatalf("total tokens changed: got=%d want=%d", total, 1200)
|
||||
}
|
||||
if usage.CacheCreation5mTokens != 0 {
|
||||
t.Fatalf("cache_creation_5m should be 0, got=%d", usage.CacheCreation5mTokens)
|
||||
}
|
||||
if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
|
||||
t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
|
||||
}
|
||||
if usage.InputTokens > 100 {
|
||||
t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
|
||||
}
|
||||
if usage.CacheCreation1hTokens <= 0 {
|
||||
t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
|
||||
}
|
||||
if usage.CacheCreationInputTokens != usage.CacheCreation1hTokens {
|
||||
t.Fatalf("cache_creation_input_tokens mismatch: got=%d want=%d", usage.CacheCreationInputTokens, usage.CacheCreation1hTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-opus-4-5",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "build context",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "what is failing now",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
||||
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
||||
if got1 != got2 {
|
||||
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
|
||||
group := &Group{
|
||||
Platform: PlatformAnthropic,
|
||||
SimulateClaudeMaxEnabled: true,
|
||||
}
|
||||
input := &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 3000,
|
||||
CacheCreationInputTokens: 0,
|
||||
CacheCreation5mTokens: 0,
|
||||
CacheCreation1hTokens: 0,
|
||||
},
|
||||
},
|
||||
ParsedRequest: &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "cached",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "tail",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
APIKey: &APIKey{Group: group},
|
||||
}
|
||||
|
||||
if !shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=true for claude group with cache signal")
|
||||
}
|
||||
|
||||
input.ParsedRequest = &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "no cache signal"},
|
||||
},
|
||||
}
|
||||
if shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=false when request has no cache signal")
|
||||
}
|
||||
|
||||
input.ParsedRequest = &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "cached",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
input.Result.Usage.CacheCreationInputTokens = 100
|
||||
if shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=false when cache creation already exists")
|
||||
}
|
||||
}
|
||||
41
backend/internal/service/claude_tokenizer.go
Normal file
41
backend/internal/service/claude_tokenizer.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
tiktoken "github.com/pkoukk/tiktoken-go"
|
||||
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
|
||||
)
|
||||
|
||||
var (
|
||||
claudeTokenizerOnce sync.Once
|
||||
claudeTokenizer *tiktoken.Tiktoken
|
||||
)
|
||||
|
||||
func getClaudeTokenizer() *tiktoken.Tiktoken {
|
||||
claudeTokenizerOnce.Do(func() {
|
||||
// Use offline loader to avoid runtime dictionary download.
|
||||
tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())
|
||||
// Use a high-capacity tokenizer as the default approximation for Claude payloads.
|
||||
enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
|
||||
if err != nil {
|
||||
enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
|
||||
}
|
||||
if err == nil {
|
||||
claudeTokenizer = enc
|
||||
}
|
||||
})
|
||||
return claudeTokenizer
|
||||
}
|
||||
|
||||
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
|
||||
enc := getClaudeTokenizer()
|
||||
if enc == nil {
|
||||
return 0, false
|
||||
}
|
||||
tokens := len(enc.EncodeOrdinary(text))
|
||||
if tokens <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return tokens, true
|
||||
}
|
||||
@@ -43,6 +43,9 @@ type ConcurrencyCache interface {
|
||||
|
||||
// 清理过期槽位(后台任务)
|
||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||
|
||||
// 启动时清理旧进程遗留槽位与等待计数
|
||||
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -59,13 +62,22 @@ func initRequestIDPrefix() string {
|
||||
return "r" + strconv.FormatUint(fallback, 36)
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking.
|
||||
// Format: {process_random_prefix}-{base36_counter}
|
||||
func RequestIDPrefix() string {
|
||||
return requestIDPrefix
|
||||
}
|
||||
|
||||
func generateRequestID() string {
|
||||
seq := requestIDCounter.Add(1)
|
||||
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
|
||||
}
|
||||
|
||||
const (
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
@@ -331,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
|
||||
}()
|
||||
}
|
||||
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
|
||||
// Uses a detached context with timeout to prevent HTTP request cancellation from
|
||||
// causing the entire batch to fail (which would show all concurrency as 0).
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]int{}, nil
|
||||
@@ -344,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
|
||||
|
||||
// Use a detached context so that a cancelled HTTP request doesn't cause
|
||||
// the Redis pipeline to fail and return all-zero concurrency counts.
|
||||
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user