Compare commits

..

63 Commits

Author SHA1 Message Date
Wesley Liddick
826090e099 Merge pull request #946 from StarryKira/antigravity-gemini-thought-signature-fix
fix Antigravity gemini thought signature fix
2026-03-12 13:51:46 +08:00
Wesley Liddick
7399de6ecc Merge pull request #938 from xvhuan/fix/account-extra-scheduler-pressure-20260311
精准收紧 accounts.extra 观测字段触发的调度重建
2026-03-12 13:51:00 +08:00
haruka
25cb5e7505 fix 第一次 400,第二次触发切账号信号 2026-03-12 11:30:53 +08:00
ius
5c13ec3121 Fix lint after rebasing PR #938 branch 2026-03-12 11:20:59 +08:00
ius
d8aff3a7e3 Merge origin/main into fix/account-extra-scheduler-pressure-20260311 2026-03-12 11:12:01 +08:00
haruka
f44927b9f8 add test for fix #935 2026-03-12 11:04:14 +08:00
Wesley Liddick
c0110cb5af Merge pull request #941 from CoolCoolTomato/main
fix: 修复gpt-5.2以上模型映射到gpt-5.2以下时verbosity参数引发的报错
2026-03-12 09:35:09 +08:00
Wesley Liddick
1f8e1142a0 Merge pull request #932 from 0xObjc/codex/usage-view-charts
feat(admin): add metric toggle to usage charts
2026-03-12 09:32:40 +08:00
Wesley Liddick
1e51de88d6 Merge pull request #937 from lxohi/fix/anthropic-stream-keepalive
fix: 为 Anthropic Messages API 流式转发添加下游 keepalive ping
2026-03-12 09:30:18 +08:00
Wesley Liddick
30995b5397 Merge pull request #936 from xvhuan/fix/ops-write-pressure-20260311
降低 ops_error_logs 与 scheduler_outbox 的数据库写放大
2026-03-12 09:28:34 +08:00
Wesley Liddick
eb60f67054 Merge pull request #933 from xvhuan/fix/dashboard-read-pressure-20260311
降低 admin/dashboard 读路径压力,避免 snapshot-v2 并发击穿
2026-03-12 09:28:14 +08:00
Wesley Liddick
78193ceec1 Merge pull request #931 from xvhuan/fix/db-write-amplification-20260311
降低 quota 与 Codex 快照热路径的数据库写放大
2026-03-12 09:27:34 +08:00
Wesley Liddick
f0e08e7687 Merge pull request #930 from GuangYiDing/feat/gemini-25-flash-image-support
feat: 修复 Gemini 生图接口并新增前端生图测试能力
2026-03-12 09:27:19 +08:00
Wesley Liddick
10b8259259 Merge pull request #909 from StarryKira/feature/admin-reset-subscription-quota
Feature/管理员可以重置账号额度
2026-03-12 09:26:47 +08:00
CoolCoolTomato
eb0b77bf4d fix: 修复流水线golangci-lint 的 errcheck 2026-03-11 22:56:20 +08:00
shaw
9d81467937 refactor: 重构 Chat Completions 端点,采用类型安全的 Responses API 转换
将 /v1/chat/completions 端点从 ResponseWriter 劫持模式重构为独立的
类型安全转换路径,与 Anthropic Messages 端点架构对齐:

- 在 apicompat 包新增 Chat Completions 完整类型定义和双向转换器
- 新增 ForwardAsChatCompletions service 方法,走 Responses API 上游
- Handler 改为独立的账号选择/failover 循环,不再劫持 Responses handler
- 提取 handleCompatErrorResponse 为 Chat Completions 和 Messages 共用
- 删除旧的 forwardChatCompletions 直传路径及相关死代码
2026-03-11 22:15:32 +08:00
CoolCoolTomato
fd8ccaf01a fix: 修复gpt-5.2以上模型映射到gpt-5.2以下时verbosity参数引发的报错 2026-03-11 21:12:07 +08:00
ius
2b30e3b6d7 Reduce scheduler rebuilds on neutral extra updates 2026-03-11 19:16:19 +08:00
amberwarden
6e90ec6111 fix: 为 Anthropic Messages API 流式转发添加下游 keepalive ping
Anthropic Messages API 的流式转发路径(gateway_service.go)在上游长时间
无数据时(如 Opus extended thinking 阶段)不会向下游发送任何内容,导致
Cloudflare Tunnel 等代理因连接空闲而断开。

复用已有的 StreamKeepaliveInterval 配置(默认 10 秒),在 select 循环中
添加 keepalive 分支,定时发送 Anthropic 原生格式的 ping 事件保活,与
OpenAI 兼容路径的实现模式保持一致。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 18:43:03 +08:00
Wesley Liddick
8dd38f4775 Merge pull request #926 from 7976723/feat/chat-completions-compat-v2
feat: 添加 OpenAI Chat Completions 兼容端点(基于 #648,修复编译错误和运行时 panic)
2026-03-11 17:42:03 +08:00
ius
fbd73f248f Fix ops write pressure integration fixture 2026-03-11 17:40:28 +08:00
Rose Ding
3fcefe6c32 feat: prioritize new gemini image models in frontend 2026-03-11 17:34:44 +08:00
ius
f740d2c291 Reduce ops and scheduler write amplification 2026-03-11 17:32:00 +08:00
Rose Ding
bf6585a40f feat: add gemini image test preview 2026-03-11 17:12:57 +08:00
ius
8c2dd7b3f0 Fix dashboard snapshot lint errors 2026-03-11 16:57:18 +08:00
ius
4167c437a8 Reduce admin dashboard read amplification 2026-03-11 16:46:58 +08:00
Peter
0ddaef3c9a feat(admin): add metric toggle to usage charts 2026-03-11 16:05:27 +08:00
ius
2fc6aaf936 Fix Codex exhausted snapshot propagation 2026-03-11 15:47:39 +08:00
Rose Ding
1c0519f1c7 feat: add gemini 2.5 flash image support 2026-03-11 15:21:52 +08:00
Wesley Liddick
6bbe7800be Merge pull request #908 from wucm667/fix/ops-alert-group-account-metrics
fix: 补充缺失的组级和账户级运维告警指标
2026-03-11 15:04:07 +08:00
ius
2694149489 Reduce DB write amplification on quota and account extra updates 2026-03-11 13:53:19 +08:00
7976723
a17ac50118 fix: 修复 Chat Completions 编译错误和运行时 panic
1. 修复 WriteFilteredHeaders API 不兼容(2处):
   将 s.cfg.Security.ResponseHeaders 改为 s.responseHeaderFilter,
   因为 main 分支已将函数签名改为接受 *responseheaders.CompiledHeaderFilter

2. 修复 writer 生命周期导致的 nil pointer panic:
   ChatCompletions handler 替换了 c.Writer 但未恢复,导致
   OpsErrorLogger 中间件的 defer 释放 opsCaptureWriter 后,
   Logger 中间件调用 c.Writer.Status() 触发空指针解引用。
   通过保存并恢复 originalWriter 修复。

3. 为 chatCompletionsResponseWriter 添加防御性 Status() 和
   Written() 方法,包含 nil 安全检查

4. 恢复 gateway.go 中被误删的 net/http import
2026-03-11 13:49:13 +08:00
7976723
656a77d585 feat: 添加 OpenAI Chat Completions 兼容端点
基于 @yulate 在 PR #648 (commit 0bb6a392) 的工作,解决了与最新
main 分支的合并冲突。

原始功能(@yulate):
- 添加 /v1/chat/completions 和 /chat/completions 兼容端点
- 将 Chat Completions 请求转换为 Responses API 格式并转换回来
- 添加 API Key 直连转发支持
- 包含单元测试

Co-authored-by: yulate <yulate@users.noreply.github.com>
2026-03-11 13:47:37 +08:00
Wesley Liddick
7455476c60 Merge pull request #918 from rickylin047/fix/responses-string-input
fix(openai): convert string input to array for Codex OAuth responses endpoint (fix #919)
2026-03-11 08:54:39 +08:00
Elysia
36cda57c81 fix copilot review issue 2026-03-10 23:59:39 +08:00
rickylin047
9f1f203b84 fix(openai): convert string input to array for Codex OAuth responses endpoint
The ChatGPT backend-api codex/responses endpoint requires `input` to be
an array, but the OpenAI Responses API spec allows it to be a plain string.
When a client sends a string input, sub2api now converts it to the expected
message array format. Empty/whitespace-only strings become an empty array
to avoid triggering a 400 "Input must be a list" error.
2026-03-10 23:43:52 +08:00
haruka
b41a8ca93f add test 2026-03-10 11:33:25 +08:00
wucm667
e3cf0c0e10 fix: 补充缺失的组级和账户级运维告警指标
新增以下运维告警指标类型:
- group_available_accounts: 组内可用账户数
- group_available_ratio: 组内可用账户比例
- group_rate_limit_ratio: 组内限速账户比例
- account_rate_limited_count: 限速账户数
- account_error_count: 错误账户数
- account_error_ratio: 错误账户比例
- overload_account_count: 过载账户数

包含比例和计数类指标的评估逻辑,并注册新的百分比类指标用于阈值校验。
2026-03-10 11:29:31 +08:00
haruka
de18bce9aa feat: add admin reset subscription quota endpoint and UI
- Add AdminResetQuota service method to reset daily/weekly usage windows
- Add POST /api/v1/admin/subscriptions/:id/reset-quota handler and route
- Add resetQuota API function in frontend subscriptions client
- Add reset quota button, confirmation dialog, and handlers in SubscriptionsView
- Add i18n keys for reset quota feature in zh and en locales

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-10 11:21:11 +08:00
Wesley Liddick
3cc407bc0e Merge pull request #900 from ischanx/feat/admin-bind-subscription-group
feat: 允许管理员为持有有效订阅的用户绑定订阅类型分组
2026-03-10 11:20:47 +08:00
shaw
00a0a12138 feat: Anthropic平台可配置 anthropic-beta 策略 2026-03-10 11:20:10 +08:00
ischanx
b08767a4f9 fix: avoid admin subscription binding regressions 2026-03-10 10:53:54 +08:00
Wesley Liddick
ac6bde7a98 Merge pull request #872 from StarryKira/fix/oauth-linuxdo-invitation-required
fix: Linux.do OAuth 注册支持邀请码两步流程 (fix #836)
2026-03-10 09:10:35 +08:00
Wesley Liddick
d2d41d68dd Merge pull request #894 from touwaeriol/pr/startup-concurrency-cleanup
feat: cleanup stale concurrency slots on startup
2026-03-10 09:08:33 +08:00
Wesley Liddick
944b7f7617 Merge pull request #904 from james-6-23/fix-pool-mode-retry
fix: OpenAI临时性400错误支持池模式同账号重试 & HelpTooltip层级修复
2026-03-10 09:08:12 +08:00
Wesley Liddick
53825eb073 Merge pull request #903 from touwaeriol/fix/openai-responses-sse-max-line-size-v2
fix: use shared max_line_size config for OpenAI Responses SSE scanner
2026-03-10 09:06:48 +08:00
Wesley Liddick
1a7f49513f Merge pull request #896 from touwaeriol/pr/fix-quota-badge-vertical-layout
fix(frontend): stack quota badges vertically in capacity column
2026-03-10 09:04:53 +08:00
Wesley Liddick
885a2ce7ef Merge pull request #893 from touwaeriol/pr/iframe-lang-passthrough
feat(frontend): pass locale to iframe embedded pages via lang parameter
2026-03-10 09:04:37 +08:00
Wesley Liddick
14ba80a0af Merge pull request #897 from DaydreamCoding/feat/batch-reset-and-openai-jwt
feat: OpenAI 账号信息增强 & 批量操作支持
2026-03-10 09:03:59 +08:00
kyx236
5fa22fdf82 fix: OpenAI临时性400错误支持池模式同账号重试 & HelpTooltip层级修复
1. 识别OpenAI "An error occurred while processing your request" 临时性400错误
   并触发failover,同时在池模式下标记RetryableOnSameAccount,允许同账号重试
2. ForwardAsAnthropic路径同步支持临时性400错误的识别和同账号重试
3. HelpTooltip组件使用Teleport渲染到body,修复在dialog内被裁切的问题
2026-03-10 03:00:58 +08:00
erio
bcaae2eb91 fix: use shared max_line_size config for OpenAI Responses SSE scanner
Two SSE scanners in openai_gateway_messages.go were hardcoded to 1MB
while all other scanners use defaultMaxLineSize (500MB) with config
override. This caused Responses API streams to fail on large payloads.
2026-03-10 02:50:04 +08:00
ischanx
767a41e263 feat: 允许管理员为持有有效订阅的用户绑定订阅类型分组
之前管理员无法通过 API 密钥管理将用户绑定到订阅类型分组(直接返回错误)。
现在改为检查用户是否持有该分组的有效订阅,有则允许绑定,无则拒绝。

- admin_service: 新增 userSubRepo 依赖,替换硬拒绝为订阅校验
- admin_service: 区分 ErrSubscriptionNotFound 和内部错误,避免 DB 故障被误报
- wire_gen/api_contract_test: 同步新增参数
- UserApiKeysModal: 管理员分组下拉不再过滤订阅类型分组

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-10 00:51:43 +08:00
QTom
252d6c5301 feat: 支持批量重置状态和批量刷新令牌
- 提取 refreshSingleAccount 私有方法复用单账号刷新逻辑
- 新增 BatchClearError handler (POST /admin/accounts/batch-clear-error)
- 新增 BatchRefresh handler (POST /admin/accounts/batch-refresh)
- 前端 AccountBulkActionsBar 添加批量重置状态/刷新令牌按钮
- AccountsView 添加 handler 支持 partial success 反馈
- i18n 中英文补充批量操作相关翻译
2026-03-09 21:54:27 +08:00
QTom
7a4e65ad4b feat: 导入账号时 best-effort 从 id_token 提取用户信息
提取 DecodeIDToken(跳过过期校验)供导入场景使用,
ParseIDToken 复用它并保留原有过期检查行为。
导入 OpenAI/Sora OAuth 账号时自动补充缺失的 email、
plan_type、chatgpt_account_id 等字段,不覆盖已有值。
2026-03-09 21:53:46 +08:00
QTom
a582aa89a9 feat: 从 OpenAI JWT 提取 chatgpt_plan_type 并在前端展示
OAuth 授权和 token 刷新时从 id_token 的 OpenAI auth claim 中
提取 chatgpt_plan_type(plus/team/pro/free),存入 credentials,
账号管理页面 PlatformTypeBadge 显示订阅类型。
2026-03-09 21:53:46 +08:00
erio
acefa1da12 fix(frontend): stack quota badges vertically in capacity column
QuotaBadge components (daily/weekly/total) were wrapped in a
horizontal flex container, making them visually inconsistent with
other capacity badges (concurrency, window cost, sessions, RPM)
which stack vertically. Remove the wrapper so all badges align
consistently.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 20:18:12 +08:00
erio
a88698f3fc feat: cleanup stale concurrency slots on startup
When the service restarts, concurrency slots from the old process
remain in Redis, causing phantom occupancy. On startup, scan all
concurrency sorted sets and remove members with non-current process
prefix, then clear orphaned wait queue counters.

Uses Go-side SCAN to discover keys (compatible with Redis client
prefix hooks in tests), then passes them to a Lua script for
atomic member-level cleanup.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 19:55:18 +08:00
erio
ebc6755b33 feat(frontend): pass locale to iframe embedded pages via lang parameter
Embedded pages (purchase subscription, custom pages) now receive the
current user locale through a `lang` URL parameter, allowing iframe
content to match the user's language preference.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 19:38:23 +08:00
shaw
c8eff34388 chore: update readme 2026-03-09 16:25:56 +08:00
shaw
f19b03825b chore: update docs 2026-03-09 16:11:44 +08:00
Elysia
b43ee62947 fix CI/CD Error 2026-03-09 13:13:39 +08:00
Elysia
106b20cdbf fix claudecode review bug 2026-03-09 01:18:49 +08:00
Elysia
c069b3b1e8 fix issue #836 linux.do注册无需邀请码 2026-03-09 00:35:34 +08:00
131 changed files with 8363 additions and 646 deletions

View File

@@ -150,14 +150,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
# Start services
docker-compose -f docker-compose.local.yml up -d
docker-compose up -d
# View logs
docker-compose -f docker-compose.local.yml logs -f sub2api
docker-compose logs -f sub2api
```
**What the script does:**
- Downloads `docker-compose.local.yml` and `.env.example`
- Downloads `docker-compose.local.yml` (saved as `docker-compose.yml`) and `.env.example`
- Generates secure credentials (JWT_SECRET, TOTP_ENCRYPTION_KEY, POSTGRES_PASSWORD)
- Creates `.env` file with auto-generated secrets
- Creates data directories (uses local directories for easy backup/migration)
@@ -522,6 +522,28 @@ sub2api/
└── install.sh # One-click installation script
```
## Disclaimer
> **Please read carefully before using this project:**
>
> :rotating_light: **Terms of Service Risk**: Using this project may violate Anthropic's Terms of Service. Please read Anthropic's user agreement carefully before use. All risks arising from the use of this project are borne solely by the user.
>
> :book: **Disclaimer**: This project is for technical learning and research purposes only. The author assumes no responsibility for account suspension, service interruption, or any other losses caused by the use of this project.
---
## Star History
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
</picture>
</a>
---
## License
MIT License

View File

@@ -154,14 +154,14 @@ mkdir -p sub2api-deploy && cd sub2api-deploy
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash
# 启动服务
docker-compose -f docker-compose.local.yml up -d
docker-compose up -d
# 查看日志
docker-compose -f docker-compose.local.yml logs -f sub2api
docker-compose logs -f sub2api
```
**脚本功能:**
- 下载 `docker-compose.local.yml` `.env.example`
- 下载 `docker-compose.local.yml`(本地保存为 `docker-compose.yml``.env.example`
- 自动生成安全凭证JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD
- 创建 `.env` 文件并填充自动生成的密钥
- 创建数据目录(使用本地目录,便于备份和迁移)
@@ -588,6 +588,28 @@ sub2api/
└── install.sh # 一键安装脚本
```
## 免责声明
> **使用本项目前请仔细阅读:**
>
> :rotating_light: **服务条款风险**: 使用本项目可能违反 Anthropic 的服务条款。请在使用前仔细阅读 Anthropic 的用户协议,使用本项目的一切风险由用户自行承担。
>
> :book: **免责声明**: 本项目仅供技术学习和研究使用,作者不对因使用本项目导致的账户封禁、服务中断或其他损失承担任何责任。
---
## Star History
<a href="https://star-history.com/#Wei-Shaw/sub2api&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Wei-Shaw/sub2api&type=Date" />
</picture>
</a>
---
## 许可证
MIT License

View File

@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
@@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)

View File

@@ -84,10 +84,12 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
// Gemini 2.5 白名单
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
// Gemini 3 白名单
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",

View File

@@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
t.Parallel()
cases := map[string]string{
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
"gemini-3-pro-image": "gemini-3.1-flash-image",

View File

@@ -8,6 +8,9 @@ import (
"strings"
"time"
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -292,6 +295,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
}
}
enrichCredentialsFromIDToken(&item)
accountInput := &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
@@ -535,6 +540,57 @@ func defaultProxyName(name string) string {
return name
}
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
// Existing credential values are never overwritten — only missing fields are filled.
func enrichCredentialsFromIDToken(item *DataAccount) {
if item.Credentials == nil {
return
}
// Only enrich OpenAI/Sora OAuth accounts
platform := strings.ToLower(strings.TrimSpace(item.Platform))
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
return
}
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
return
}
idToken, _ := item.Credentials["id_token"].(string)
if strings.TrimSpace(idToken) == "" {
return
}
// DecodeIDToken skips expiry validation — safe for imported data
claims, err := openai.DecodeIDToken(idToken)
if err != nil {
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
return
}
userInfo := claims.GetUserInfo()
if userInfo == nil {
return
}
// Fill missing fields only (never overwrite existing values)
setIfMissing := func(key, value string) {
if value == "" {
return
}
if existing, _ := item.Credentials[key].(string); existing == "" {
item.Credentials[key] = value
}
}
setIfMissing("email", userInfo.Email)
setIfMissing("plan_type", userInfo.PlanType)
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
setIfMissing("organization_id", userInfo.OrganizationID)
}
func normalizeProxyStatus(status string) string {
normalized := strings.TrimSpace(strings.ToLower(status))
switch normalized {

View File

@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"strings"
@@ -18,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -626,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
// TestAccountRequest represents the request body for testing an account
type TestAccountRequest struct {
ModelID string `json:"model_id"`
Prompt string `json:"prompt"`
}
type SyncFromCRSRequest struct {
@@ -656,7 +659,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
// Error already sent via SSE, just log
return
}
@@ -751,52 +754,31 @@ func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
response.Success(c, result)
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// Only refresh OAuth-based accounts (oauth and setup-token)
// refreshSingleAccount refreshes credentials for a single OAuth account.
// Returns (updatedAccount, warning, error) where warning is used for Antigravity ProjectIDMissing scenario.
func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *service.Account) (*service.Account, string, error) {
if !account.IsOAuth() {
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
return
return nil, "", infraerrors.BadRequest("NOT_OAUTH", "cannot refresh non-OAuth account")
}
var newCredentials map[string]any
if account.IsOpenAI() {
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
response.ErrorFrom(c, err)
return
return nil, "", err
}
// Build new credentials from token info
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
} else if account.Platform == service.PlatformGemini {
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
return nil, "", fmt.Errorf("failed to refresh credentials: %w", err)
}
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
@@ -806,10 +788,9 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
} else if account.Platform == service.PlatformAntigravity {
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(c.Request.Context(), account)
tokenInfo, err := h.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
response.ErrorFrom(c, err)
return
return nil, "", err
}
newCredentials = h.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
@@ -828,37 +809,27 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
// 如果 project_id 获取失败,更新凭证但不标记为 error
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
if tokenInfo.ProjectIDMissing {
// 先更新凭证token 本身刷新成功了)
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
updatedAccount, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if updateErr != nil {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
}
// 不标记为 error只返回警告信息
response.Success(c, gin.H{
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
"warning": "missing_project_id_temporary",
})
return
return updatedAccount, "missing_project_id_temporary", nil
}
// 成功获取到 project_id如果之前是 missing_project_id 错误则清除
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
return
if _, clearErr := h.adminService.ClearAccountError(ctx, account.ID); clearErr != nil {
return nil, "", fmt.Errorf("failed to clear account error: %w", clearErr)
}
}
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
tokenInfo, err := h.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
response.ErrorFrom(c, err)
return
return nil, "", err
}
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
@@ -880,20 +851,51 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
updatedAccount, err := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
return nil, "", err
}
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
if h.tokenCacheInvalidator != nil {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", updatedAccount.ID, invalidateErr)
}
}
return updatedAccount, "", nil
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
updatedAccount, warning, err := h.refreshSingleAccount(c.Request.Context(), account)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
if h.tokenCacheInvalidator != nil {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
}
if warning == "missing_project_id_temporary" {
response.Success(c, gin.H{
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
"warning": "missing_project_id_temporary",
})
return
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
@@ -949,14 +951,175 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
// 缓存失效失败只记录日志,不影响主流程
_ = c.Error(invalidateErr)
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
}
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// BatchClearError handles batch clearing account errors
// POST /api/v1/admin/accounts/batch-clear-error
func (h *AccountHandler) BatchClearError(c *gin.Context) {
var req struct {
AccountIDs []int64 `json:"account_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.AccountIDs) == 0 {
response.BadRequest(c, "account_ids is required")
return
}
ctx := c.Request.Context()
const maxConcurrency = 10
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(maxConcurrency)
var mu sync.Mutex
var successCount, failedCount int
var errors []gin.H
// 注意:所有 goroutine 必须 return nil避免 errgroup cancel 其他并发任务
for _, id := range req.AccountIDs {
accountID := id // 闭包捕获
g.Go(func() error {
account, err := h.adminService.ClearAccountError(gctx, accountID)
if err != nil {
mu.Lock()
failedCount++
errors = append(errors, gin.H{
"account_id": accountID,
"error": err.Error(),
})
mu.Unlock()
return nil
}
// 清除错误后,同时清除 token 缓存
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(gctx, account); invalidateErr != nil {
log.Printf("[WARN] Failed to invalidate token cache for account %d: %v", accountID, invalidateErr)
}
}
mu.Lock()
successCount++
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"total": len(req.AccountIDs),
"success": successCount,
"failed": failedCount,
"errors": errors,
})
}
// BatchRefresh handles batch refreshing account credentials
// POST /api/v1/admin/accounts/batch-refresh
func (h *AccountHandler) BatchRefresh(c *gin.Context) {
var req struct {
AccountIDs []int64 `json:"account_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.AccountIDs) == 0 {
response.BadRequest(c, "account_ids is required")
return
}
ctx := c.Request.Context()
accounts, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 建立已获取账号的 ID 集合,检测缺失的 ID
foundIDs := make(map[int64]bool, len(accounts))
for _, acc := range accounts {
if acc != nil {
foundIDs[acc.ID] = true
}
}
const maxConcurrency = 10
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(maxConcurrency)
var mu sync.Mutex
var successCount, failedCount int
var errors []gin.H
var warnings []gin.H
// 将不存在的账号 ID 标记为失败
for _, id := range req.AccountIDs {
if !foundIDs[id] {
failedCount++
errors = append(errors, gin.H{
"account_id": id,
"error": "account not found",
})
}
}
// 注意:所有 goroutine 必须 return nil避免 errgroup cancel 其他并发任务
for _, account := range accounts {
acc := account // 闭包捕获
if acc == nil {
continue
}
g.Go(func() error {
_, warning, err := h.refreshSingleAccount(gctx, acc)
mu.Lock()
if err != nil {
failedCount++
errors = append(errors, gin.H{
"account_id": acc.ID,
"error": err.Error(),
})
} else {
successCount++
if warning != "" {
warnings = append(warnings, gin.H{
"account_id": acc.ID,
"warning": warning,
})
}
}
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"total": len(req.AccountIDs),
"success": successCount,
"failed": failedCount,
"errors": errors,
"warnings": warnings,
})
}
// BatchCreate handles batch creating accounts
// POST /api/v1/admin/accounts/batch
func (h *AccountHandler) BatchCreate(c *gin.Context) {

View File

@@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,
@@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"models": stats,
@@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"groups": stats,
@@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
limit = 5
}
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,
@@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
limit = 12
}
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,

View File

@@ -0,0 +1,118 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCacheProbe struct {
service.UsageLogRepository
trendCalls atomic.Int32
usersTrendCalls atomic.Int32
}
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
r.trendCalls.Add(1)
return []usagestats.TrendDataPoint{{
Date: "2026-03-11",
Requests: 1,
TotalTokens: 2,
Cost: 3,
ActualCost: 4,
}}, nil
}
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
limit int,
) ([]usagestats.UserUsageTrendPoint, error) {
r.usersTrendCalls.Add(1)
return []usagestats.UserUsageTrendPoint{{
Date: "2026-03-11",
UserID: 1,
Email: "cache@test.dev",
Requests: 2,
Tokens: 20,
Cost: 2,
ActualCost: 1,
}}, nil
}
func resetDashboardReadCachesForTest() {
dashboardTrendCache = newSnapshotCache(30 * time.Second)
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
}
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
t.Cleanup(resetDashboardReadCachesForTest)
resetDashboardReadCachesForTest()
gin.SetMode(gin.TestMode)
repo := &dashboardUsageRepoCacheProbe{}
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
require.Equal(t, int32(1), repo.trendCalls.Load())
}
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
t.Cleanup(resetDashboardReadCachesForTest)
resetDashboardReadCachesForTest()
gin.SetMode(gin.TestMode)
repo := &dashboardUsageRepoCacheProbe{}
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
}

View File

@@ -0,0 +1,200 @@
package admin
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
var (
dashboardTrendCache = newSnapshotCache(30 * time.Second)
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
)
type dashboardTrendCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Model string `json:"model"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
}
type dashboardModelGroupCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
}
type dashboardEntityTrendCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
Limit int `json:"limit"`
}
func cacheStatusValue(hit bool) string {
if hit {
return "hit"
}
return "miss"
}
func mustMarshalDashboardCacheKey(value any) string {
raw, err := json.Marshal(value)
if err != nil {
return ""
}
return string(raw)
}
func snapshotPayloadAs[T any](payload any) (T, error) {
typed, ok := payload.(T)
if !ok {
var zero T
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
}
return typed, nil
}
func (h *DashboardHandler) getUsageTrendCached(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
return trend, hit, err
}
func (h *DashboardHandler) getModelStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
return stats, hit, err
}
func (h *DashboardHandler) getGroupStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.GroupStat, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
return stats, hit, err
}
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
Limit: limit,
})
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
return trend, hit, err
}
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
Limit: limit,
})
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
return trend, hit, err
}

View File

@@ -1,7 +1,9 @@
package admin
import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
@@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
})
cacheKey := string(keyRaw)
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
return h.buildSnapshotV2Response(
c.Request.Context(),
startTime,
endTime,
granularity,
filters,
includeStats,
includeTrend,
includeModels,
includeGroups,
includeUsersTrend,
usersTrendLimit,
)
})
if err != nil {
response.Error(c, 500, err.Error())
return
}
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, cached.Payload)
}
func (h *DashboardHandler) buildSnapshotV2Response(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
filters *dashboardSnapshotV2Filters,
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
usersTrendLimit int,
) (*dashboardSnapshotV2Response, error) {
resp := &dashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
StartDate: startTime.Format("2006-01-02"),
@@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
}
if includeStats {
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
stats, err := h.dashboardService.GetDashboardStats(ctx)
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
return nil, errors.New("failed to get dashboard statistics")
}
resp.Stats = &dashboardSnapshotV2Stats{
DashboardStats: *stats,
@@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
}
if includeTrend {
trend, err := h.dashboardService.GetUsageTrendWithFilters(
c.Request.Context(),
trend, _, err := h.getUsageTrendCached(
ctx,
startTime,
endTime,
granularity,
@@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
return nil, errors.New("failed to get usage trend")
}
resp.Trend = trend
}
if includeModels {
models, err := h.dashboardService.GetModelStatsWithFilters(
c.Request.Context(),
models, _, err := h.getModelStatsCached(
ctx,
startTime,
endTime,
filters.UserID,
@@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
return nil, errors.New("failed to get model statistics")
}
resp.Models = models
}
if includeGroups {
groups, err := h.dashboardService.GetGroupStatsWithFilters(
c.Request.Context(),
groups, _, err := h.getGroupStatsCached(
ctx,
startTime,
endTime,
filters.UserID,
@@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
return nil, errors.New("failed to get group statistics")
}
resp.Groups = groups
}
if includeUsersTrend {
usersTrend, err := h.dashboardService.GetUserUsageTrend(
c.Request.Context(),
startTime,
endTime,
granularity,
usersTrendLimit,
)
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
return nil, errors.New("failed to get user usage trend")
}
resp.UsersTrend = usersTrend
}
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
return resp, nil
}
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {

View File

@@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{
"cpu_usage_percent",
"memory_usage_percent",
"concurrency_queue_depth",
"group_available_accounts",
"group_available_ratio",
"group_rate_limit_ratio",
"account_rate_limited_count",
"account_error_count",
"account_error_ratio",
"overload_account_count",
}
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
@@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool {
"error_rate",
"upstream_error_rate",
"cpu_usage_percent",
"memory_usage_percent":
"memory_usage_percent",
"group_available_ratio",
"group_rate_limit_ratio",
"account_error_ratio":
return true
default:
return false

View File

@@ -1405,6 +1405,61 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
})
}
// GetBetaPolicySettings 获取 Beta 策略配置
// GET /api/v1/admin/settings/beta-policy
func (h *SettingHandler) GetBetaPolicySettings(c *gin.Context) {
settings, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
rules := make([]dto.BetaPolicyRule, len(settings.Rules))
for i, r := range settings.Rules {
rules[i] = dto.BetaPolicyRule(r)
}
response.Success(c, dto.BetaPolicySettings{Rules: rules})
}
// UpdateBetaPolicySettingsRequest 更新 Beta 策略配置请求
type UpdateBetaPolicySettingsRequest struct {
Rules []dto.BetaPolicyRule `json:"rules"`
}
// UpdateBetaPolicySettings 更新 Beta 策略配置
// PUT /api/v1/admin/settings/beta-policy
func (h *SettingHandler) UpdateBetaPolicySettings(c *gin.Context) {
var req UpdateBetaPolicySettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
rules := make([]service.BetaPolicyRule, len(req.Rules))
for i, r := range req.Rules {
rules[i] = service.BetaPolicyRule(r)
}
settings := &service.BetaPolicySettings{Rules: rules}
if err := h.settingService.SetBetaPolicySettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
// Re-fetch to return updated settings
updated, err := h.settingService.GetBetaPolicySettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
outRules := make([]dto.BetaPolicyRule, len(updated.Rules))
for i, r := range updated.Rules {
outRules[i] = dto.BetaPolicyRule(r)
}
response.Success(c, dto.BetaPolicySettings{Rules: outRules})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"`

View File

@@ -7,6 +7,8 @@ import (
"strings"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
type snapshotCacheEntry struct {
@@ -19,6 +21,12 @@ type snapshotCache struct {
mu sync.RWMutex
ttl time.Duration
items map[string]snapshotCacheEntry
sf singleflight.Group
}
type snapshotCacheLoadResult struct {
Entry snapshotCacheEntry
Hit bool
}
func newSnapshotCache(ttl time.Duration) *snapshotCache {
@@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
return entry
}
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
if load == nil {
return snapshotCacheEntry{}, false, nil
}
if entry, ok := c.Get(key); ok {
return entry, true, nil
}
if c == nil || key == "" {
payload, err := load()
if err != nil {
return snapshotCacheEntry{}, false, err
}
return c.Set(key, payload), false, nil
}
value, err, _ := c.sf.Do(key, func() (any, error) {
if entry, ok := c.Get(key); ok {
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
}
payload, err := load()
if err != nil {
return nil, err
}
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
})
if err != nil {
return snapshotCacheEntry{}, false, err
}
result, ok := value.(snapshotCacheLoadResult)
if !ok {
return snapshotCacheEntry{}, false, nil
}
return result.Entry, result.Hit, nil
}
func buildETagFromAny(payload any) string {
raw, err := json.Marshal(payload)
if err != nil {

View File

@@ -3,6 +3,8 @@
package admin
import (
"sync"
"sync/atomic"
"testing"
"time"
@@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
require.Empty(t, etag)
}
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
var loads atomic.Int32
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
loads.Add(1)
return map[string]string{"hello": "world"}, nil
})
require.NoError(t, err)
require.False(t, hit)
require.NotEmpty(t, entry.ETag)
require.Equal(t, int32(1), loads.Load())
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
loads.Add(1)
return map[string]string{"unexpected": "value"}, nil
})
require.NoError(t, err)
require.True(t, hit)
require.Equal(t, entry.ETag, entry2.ETag)
require.Equal(t, int32(1), loads.Load())
}
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
var loads atomic.Int32
start := make(chan struct{})
const callers = 8
errCh := make(chan error, callers)
var wg sync.WaitGroup
wg.Add(callers)
for range callers {
go func() {
defer wg.Done()
<-start
_, _, err := c.GetOrLoad("shared", func() (any, error) {
loads.Add(1)
time.Sleep(20 * time.Millisecond)
return "value", nil
})
errCh <- err
}()
}
close(start)
wg.Wait()
close(errCh)
for err := range errCh {
require.NoError(t, err)
}
require.Equal(t, int32(1), loads.Load())
}
func TestParseBoolQueryWithDefault(t *testing.T) {
tests := []struct {
name string

View File

@@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
})
}
// ResetSubscriptionQuotaRequest represents the reset quota request
type ResetSubscriptionQuotaRequest struct {
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
}
// ResetQuota resets daily and/or weekly usage for a subscription.
// POST /api/v1/admin/subscriptions/:id/reset-quota
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
var req ResetSubscriptionQuotaRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Daily && !req.Weekly {
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
return
}
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
}
// Revoke handles revoking a subscription
// DELETE /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) Revoke(c *gin.Context) {

View File

@@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
email = linuxDoSyntheticEmail(subject)
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
if tokenErr != nil {
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
return
}
fragment := url.Values{}
fragment.Set("error", "invitation_required")
fragment.Set("pending_oauth_token", pendingToken)
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
return
}
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
@@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
redirectWithFragment(c, frontendCallback, fragment)
}
type completeLinuxDoOAuthRequest struct {
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
InvitationCode string `json:"invitation_code" binding:"required"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
// the invitation code and creating the user account.
// POST /api/v1/auth/oauth/linuxdo/complete-registration
func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
var req completeLinuxDoOAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
return
}
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
return
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)

View File

@@ -168,6 +168,19 @@ type RectifierSettings struct {
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
}
// BetaPolicyRule Beta 策略规则 DTO
type BetaPolicyRule struct {
BetaToken string `json:"beta_token"`
Action string `json:"action"`
Scope string `json:"scope"`
ErrorMessage string `json:"error_message,omitempty"`
}
// BetaPolicySettings Beta 策略配置 DTO
type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {

View File

@@ -652,6 +652,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc()
}
if err != nil {
// Beta policy block: return 400 immediately, no failover
var betaBlockedErr *service.BetaBlockedError
if errors.As(err, &betaBlockedErr) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
return
}
var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) {
reqLog.Warn("gateway.prompt_too_long_from_antigravity",

View File

@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
return result, nil
}
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil }
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
t.Helper()

View File

@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
return nil
}
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {

View File

@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
return nil
}
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()

View File

@@ -0,0 +1,290 @@
package handler
import (
"context"
"errors"
"net/http"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API requests.
// POST /v1/chat/completions
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
for {
c.Set("openai_chat_completions_fallback_model", "")
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
reqLog.Info("openai_chat_completions.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)
}
}
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
} else {
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
}
return
}
}
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
account := selection.Account
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
_ = scheduleDecision
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Warn("openai_chat_completions.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai_chat_completions.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}

View File

@@ -31,6 +31,7 @@ const (
const (
opsErrorLogTimeout = 5 * time.Second
opsErrorLogDrainTimeout = 10 * time.Second
opsErrorLogBatchWindow = 200 * time.Millisecond
opsErrorLogMinWorkerCount = 4
opsErrorLogMaxWorkerCount = 32
@@ -38,6 +39,7 @@ const (
opsErrorLogQueueSizePerWorker = 128
opsErrorLogMinQueueSize = 256
opsErrorLogMaxQueueSize = 8192
opsErrorLogBatchSize = 32
)
type opsErrorLogJob struct {
@@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() {
for i := 0; i < workerCount; i++ {
go func() {
defer opsErrorLogWorkersWg.Done()
for job := range opsErrorLogQueue {
opsErrorLogQueueLen.Add(-1)
if job.ops == nil || job.entry == nil {
continue
for {
job, ok := <-opsErrorLogQueue
if !ok {
return
}
func() {
defer func() {
if r := recover(); r != nil {
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
opsErrorLogQueueLen.Add(-1)
batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize)
batch = append(batch, job)
timer := time.NewTimer(opsErrorLogBatchWindow)
batchLoop:
for len(batch) < opsErrorLogBatchSize {
select {
case nextJob, ok := <-opsErrorLogQueue:
if !ok {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
flushOpsErrorLogBatch(batch)
return
}
}()
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry, nil)
cancel()
opsErrorLogProcessed.Add(1)
}()
opsErrorLogQueueLen.Add(-1)
batch = append(batch, nextJob)
case <-timer.C:
break batchLoop
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
flushOpsErrorLogBatch(batch)
}
}()
}
}
func flushOpsErrorLogBatch(batch []opsErrorLogJob) {
if len(batch) == 0 {
return
}
defer func() {
if r := recover(); r != nil {
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
}
}()
grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, len(batch))
var processed int64
for _, job := range batch {
if job.ops == nil || job.entry == nil {
continue
}
grouped[job.ops] = append(grouped[job.ops], job.entry)
processed++
}
if processed == 0 {
return
}
for opsSvc, entries := range grouped {
if opsSvc == nil || len(entries) == 0 {
continue
}
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = opsSvc.RecordErrorBatch(ctx, entries)
cancel()
}
opsErrorLogProcessed.Add(processed)
}
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
if ops == nil || entry == nil {
return

View File

@@ -159,6 +159,8 @@ var claudeModels = []modelDef{
// Antigravity 支持的 Gemini 模型
var geminiModels = []modelDef{
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},

View File

@@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
requiredIDs := []string{
"claude-opus-4-6-thinking",
"gemini-2.5-flash-image",
"gemini-2.5-flash-image-preview",
"gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview",
"gemini-3-pro-image", // legacy compatibility

View File

@@ -0,0 +1,733 @@
package apicompat
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// ChatCompletionsToResponses tests
// ---------------------------------------------------------------------------
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hello"`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
assert.Equal(t, "gpt-4o", resp.Model)
assert.True(t, resp.Stream) // always forced true
assert.False(t, *resp.Store)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1)
assert.Equal(t, "user", items[0].Role)
}
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
{Role: "user", Content: json.RawMessage(`"Hi"`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
assert.Equal(t, "system", items[0].Role)
assert.Equal(t, "user", items[1].Role)
}
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
{
Role: "assistant",
ToolCalls: []ChatToolCall{
{
ID: "call_1",
Type: "function",
Function: ChatFunctionCall{
Name: "ping",
Arguments: `{"host":"example.com"}`,
},
},
},
},
{
Role: "tool",
ToolCallID: "call_1",
Content: json.RawMessage(`"pong"`),
},
},
Tools: []ChatTool{
{
Type: "function",
Function: &ChatFunction{
Name: "ping",
Description: "Ping a host",
Parameters: json.RawMessage(`{"type":"object"}`),
},
},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
// user + function_call + function_call_output = 3
// (assistant message with empty content + tool_calls → only function_call items emitted)
require.Len(t, items, 3)
// Check function_call item
assert.Equal(t, "function_call", items[1].Type)
assert.Equal(t, "call_1", items[1].CallID)
assert.Equal(t, "ping", items[1].Name)
// Check function_call_output item
assert.Equal(t, "function_call_output", items[2].Type)
assert.Equal(t, "call_1", items[2].CallID)
assert.Equal(t, "pong", items[2].Output)
// Check tools
require.Len(t, resp.Tools, 1)
assert.Equal(t, "function", resp.Tools[0].Type)
assert.Equal(t, "ping", resp.Tools[0].Name)
}
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
t.Run("max_tokens", func(t *testing.T) {
maxTokens := 100
req := &ChatCompletionsRequest{
Model: "gpt-4o",
MaxTokens: &maxTokens,
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.NotNil(t, resp.MaxOutputTokens)
// Below minMaxOutputTokens (128), should be clamped
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
})
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
maxTokens := 100
maxCompletion := 500
req := &ChatCompletionsRequest{
Model: "gpt-4o",
MaxTokens: &maxTokens,
MaxCompletionTokens: &maxCompletion,
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.NotNil(t, resp.MaxOutputTokens)
assert.Equal(t, 500, *resp.MaxOutputTokens)
})
}
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
ReasoningEffort: "high",
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.NotNil(t, resp.Reasoning)
assert.Equal(t, "high", resp.Reasoning.Effort)
assert.Equal(t, "auto", resp.Reasoning.Summary)
}
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(content)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
require.Len(t, parts, 2)
assert.Equal(t, "input_text", parts[0].Type)
assert.Equal(t, "Describe this", parts[0].Text)
assert.Equal(t, "input_image", parts[1].Type)
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
}
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
},
Functions: []ChatFunction{
{
Name: "get_weather",
Description: "Get weather",
Parameters: json.RawMessage(`{"type":"object"}`),
},
},
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.Len(t, resp.Tools, 1)
assert.Equal(t, "function", resp.Tools[0].Type)
assert.Equal(t, "get_weather", resp.Tools[0].Name)
// tool_choice should be converted
require.NotNil(t, resp.ToolChoice)
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
}
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
ServiceTier: "flex",
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
assert.Equal(t, "flex", resp.ServiceTier)
}
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Do something"`)},
{
Role: "assistant",
Content: json.RawMessage(`"Let me call a function."`),
ToolCalls: []ChatToolCall{
{
ID: "call_abc",
Type: "function",
Function: ChatFunctionCall{
Name: "do_thing",
Arguments: `{}`,
},
},
},
},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
// user + assistant message (with text) + function_call
require.Len(t, items, 3)
assert.Equal(t, "user", items[0].Role)
assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type)
}
// ---------------------------------------------------------------------------
// ResponsesToChatCompletions tests
// ---------------------------------------------------------------------------
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_123",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "Hello, world!"},
},
},
},
Usage: &ResponsesUsage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
assert.Equal(t, "chat.completion", chat.Object)
assert.Equal(t, "gpt-4o", chat.Model)
require.Len(t, chat.Choices, 1)
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
assert.Equal(t, "Hello, world!", content)
require.NotNil(t, chat.Usage)
assert.Equal(t, 10, chat.Usage.PromptTokens)
assert.Equal(t, 5, chat.Usage.CompletionTokens)
assert.Equal(t, 15, chat.Usage.TotalTokens)
}
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_456",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_xyz",
Name: "get_weather",
Arguments: `{"city":"NYC"}`,
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.Len(t, chat.Choices, 1)
assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)
msg := chat.Choices[0].Message
require.Len(t, msg.ToolCalls, 1)
assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
assert.Equal(t, "function", msg.ToolCalls[0].Type)
assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
}
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_789",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "reasoning",
Summary: []ResponsesSummary{
{Type: "summary_text", Text: "I thought about it."},
},
},
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "The answer is 42."},
},
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.Len(t, chat.Choices, 1)
var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
// Reasoning summary is prepended to text
assert.Equal(t, "I thought about it.The answer is 42.", content)
}
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_inc",
Status: "incomplete",
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "partial..."},
},
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.Len(t, chat.Choices, 1)
assert.Equal(t, "length", chat.Choices[0].FinishReason)
}
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cache",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
},
},
Usage: &ResponsesUsage{
InputTokens: 100,
OutputTokens: 10,
TotalTokens: 110,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 80,
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.NotNil(t, chat.Usage)
require.NotNil(t, chat.Usage.PromptTokensDetails)
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
}
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_ws",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "web_search_call",
Action: &WebSearchAction{Type: "search", Query: "test"},
},
{
Type: "message",
Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.Len(t, chat.Choices, 1)
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
assert.Equal(t, "search results", content)
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesEventToChatChunks tests
// ---------------------------------------------------------------------------
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
// response.created → role chunk
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{
ID: "resp_stream",
},
}, state)
require.Len(t, chunks, 1)
assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
assert.True(t, state.SentRole)
// response.output_text.delta → content chunk
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_text.delta",
Delta: "Hello",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
}
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SentRole = true
// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_item.added",
OutputIndex: 1,
Item: &ResponsesOutput{
Type: "function_call",
CallID: "call_1",
Name: "get_weather",
},
}, state)
require.Len(t, chunks, 1)
require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
tc := chunks[0].Choices[0].Delta.ToolCalls[0]
assert.Equal(t, "call_1", tc.ID)
assert.Equal(t, "get_weather", tc.Function.Name)
require.NotNil(t, tc.Index)
assert.Equal(t, 0, *tc.Index)
// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 1, // matches the output_index from output_item.added above
Delta: `{"city":`,
}, state)
require.Len(t, chunks, 1)
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
require.NotNil(t, tc.Index)
assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
assert.Equal(t, `{"city":`, tc.Function.Arguments)
// Add a second function call at output_index=2
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_item.added",
OutputIndex: 2,
Item: &ResponsesOutput{
Type: "function_call",
CallID: "call_2",
Name: "get_time",
},
}, state)
require.Len(t, chunks, 1)
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
require.NotNil(t, tc.Index)
assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")
// Argument delta for second tool call
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 2,
Delta: `{"tz":"UTC"}`,
}, state)
require.Len(t, chunks, 1)
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
require.NotNil(t, tc.Index)
assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")
// Argument delta for first tool call (interleaved)
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 1,
Delta: `"Tokyo"}`,
}, state)
require.Len(t, chunks, 1)
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
require.NotNil(t, tc.Index)
assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
}
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 50,
OutputTokens: 20,
TotalTokens: 70,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 30,
},
},
},
}, state)
// finish chunk + usage chunk
require.Len(t, chunks, 2)
// First chunk: finish_reason
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
// Second chunk: usage
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
}
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SawToolCall = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
},
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
}
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SentRole = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.delta",
Delta: "Thinking...",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
}
func TestFinalizeResponsesChatStream(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
state.Usage = &ChatUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
}
chunks := FinalizeResponsesChatStream(state)
require.Len(t, chunks, 2)
// Finish chunk
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
// Usage chunk
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 100, chunks[1].Usage.PromptTokens)
// Idempotent: second call returns nil
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
// must be a no-op (prevents double finish_reason being sent to the client).
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
// Simulate response.completed
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
},
}, state)
require.NotEmpty(t, chunks) // finish + usage chunks
// Now FinalizeResponsesChatStream should return nil — already finalized.
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestChatChunkToSSE(t *testing.T) {
chunk := ChatCompletionsChunk{
ID: "chatcmpl-test",
Object: "chat.completion.chunk",
Created: 1700000000,
Model: "gpt-4o",
Choices: []ChatChunkChoice{
{
Index: 0,
Delta: ChatDelta{Role: "assistant"},
FinishReason: nil,
},
},
}
sse, err := ChatChunkToSSE(chunk)
require.NoError(t, err)
assert.Contains(t, sse, "data: ")
assert.Contains(t, sse, "chatcmpl-test")
assert.Contains(t, sse, "assistant")
assert.True(t, len(sse) > 10)
}
// ---------------------------------------------------------------------------
// Stream round-trip test
// ---------------------------------------------------------------------------
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
// Simulate: client sends chat completions request, upstream returns Responses SSE events.
// Verify that the streaming state machine produces correct chat completions chunks.
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
var allChunks []ChatCompletionsChunk
// 1. response.created
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_rt"},
}, state)
allChunks = append(allChunks, chunks...)
// 2. text deltas
for _, text := range []string{"Hello", ", ", "world", "!"} {
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_text.delta",
Delta: text,
}, state)
allChunks = append(allChunks, chunks...)
}
// 3. response.completed
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 10,
OutputTokens: 4,
TotalTokens: 14,
},
},
}, state)
allChunks = append(allChunks, chunks...)
// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
require.Len(t, allChunks, 7)
// First chunk has role
assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)
// Text chunks
var fullText string
for i := 1; i <= 4; i++ {
require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
fullText += *allChunks[i].Choices[0].Delta.Content
}
assert.Equal(t, "Hello, world!", fullText)
// Finish chunk
require.NotNil(t, allChunks[5].Choices[0].FinishReason)
assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)
// Usage chunk
require.NotNil(t, allChunks[6].Usage)
assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)
// All chunks share the same ID
for _, c := range allChunks {
assert.Equal(t, "resp_rt", c.ID)
}
}

View File

@@ -0,0 +1,312 @@
package apicompat
import (
"encoding/json"
"fmt"
)
// ChatCompletionsToResponses converts a Chat Completions request into a
// Responses API request. The upstream always streams, so Stream is forced to
// true. store is always false and reasoning.encrypted_content is always
// included so that the response translator has full context.
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
input, err := convertChatMessagesToResponsesInput(req.Messages)
if err != nil {
return nil, err
}
inputJSON, err := json.Marshal(input)
if err != nil {
return nil, err
}
out := &ResponsesRequest{
Model: req.Model,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: true, // upstream always streams
Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
}
storeFalse := false
out.Store = &storeFalse
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
maxTokens := 0
if req.MaxTokens != nil {
maxTokens = *req.MaxTokens
}
if req.MaxCompletionTokens != nil {
maxTokens = *req.MaxCompletionTokens
}
if maxTokens > 0 {
v := maxTokens
if v < minMaxOutputTokens {
v = minMaxOutputTokens
}
out.MaxOutputTokens = &v
}
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
if req.ReasoningEffort != "" {
out.Reasoning = &ResponsesReasoning{
Effort: req.ReasoningEffort,
Summary: "auto",
}
}
// tools[] and legacy functions[] → ResponsesTool[]
if len(req.Tools) > 0 || len(req.Functions) > 0 {
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
}
// tool_choice: already compatible format — pass through directly.
// Legacy function_call needs mapping.
if len(req.ToolChoice) > 0 {
out.ToolChoice = req.ToolChoice
} else if len(req.FunctionCall) > 0 {
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
if err != nil {
return nil, fmt.Errorf("convert function_call: %w", err)
}
out.ToolChoice = tc
}
return out, nil
}
// convertChatMessagesToResponsesInput converts the Chat Completions messages
// array into a Responses API input items array.
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
var out []ResponsesInputItem
for _, m := range msgs {
items, err := chatMessageToResponsesItems(m)
if err != nil {
return nil, err
}
out = append(out, items...)
}
return out, nil
}
// chatMessageToResponsesItems converts a single ChatMessage into one or more
// ResponsesInputItem values.
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
switch m.Role {
case "system":
return chatSystemToResponses(m)
case "user":
return chatUserToResponses(m)
case "assistant":
return chatAssistantToResponses(m)
case "tool":
return chatToolToResponses(m)
case "function":
return chatFunctionToResponses(m)
default:
return chatUserToResponses(m)
}
}
// chatSystemToResponses converts a system message.
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
text, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
content, err := json.Marshal(text)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
}
// chatUserToResponses converts a user message, handling both plain strings and
// multi-modal content arrays.
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Try plain string first.
var s string
if err := json.Unmarshal(m.Content, &s); err == nil {
content, _ := json.Marshal(s)
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
var parts []ChatContentPart
if err := json.Unmarshal(m.Content, &parts); err != nil {
return nil, fmt.Errorf("parse user content: %w", err)
}
var responseParts []ResponsesContentPart
for _, p := range parts {
switch p.Type {
case "text":
if p.Text != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_text",
Text: p.Text,
})
}
case "image_url":
if p.ImageURL != nil && p.ImageURL.URL != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
})
}
}
}
content, err := json.Marshal(responseParts)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
// chatAssistantToResponses converts an assistant message. If there is both
// text content and tool_calls, the text is emitted as an assistant message
// first, then each tool_call becomes a function_call item. If the content is
// empty/nil and there are tool_calls, only function_call items are emitted.
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
var items []ResponsesInputItem
// Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 {
var s string
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
}
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
}
}
// Emit one function_call item per tool_call.
for _, tc := range m.ToolCalls {
args := tc.Function.Arguments
if args == "" {
args = "{}"
}
items = append(items, ResponsesInputItem{
Type: "function_call",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: args,
ID: tc.ID,
})
}
return items, nil
}
// chatToolToResponses converts a tool result message (role=tool) into a
// function_call_output item.
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
output, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
if output == "" {
output = "(empty)"
}
return []ResponsesInputItem{{
Type: "function_call_output",
CallID: m.ToolCallID,
Output: output,
}}, nil
}
// chatFunctionToResponses converts a legacy function result message
// (role=function) into a function_call_output item. The Name field is used as
// call_id since legacy function calls do not carry a separate call_id.
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
output, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
if output == "" {
output = "(empty)"
}
return []ResponsesInputItem{{
Type: "function_call_output",
CallID: m.Name,
Output: output,
}}, nil
}
// parseChatContent returns the string value of a ChatMessage Content field.
// Content must be a JSON string. Returns "" if content is null or empty.
func parseChatContent(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", nil
}
var s string
if err := json.Unmarshal(raw, &s); err != nil {
return "", fmt.Errorf("parse content as string: %w", err)
}
return s, nil
}
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
// function definitions to Responses API tool definitions.
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
var out []ResponsesTool
for _, t := range tools {
if t.Type != "function" || t.Function == nil {
continue
}
rt := ResponsesTool{
Type: "function",
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: t.Function.Parameters,
Strict: t.Function.Strict,
}
out = append(out, rt)
}
// Legacy functions[] are treated as function-type tools.
for _, f := range functions {
rt := ResponsesTool{
Type: "function",
Name: f.Name,
Description: f.Description,
Parameters: f.Parameters,
Strict: f.Strict,
}
out = append(out, rt)
}
return out
}
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
// Responses API tool_choice value.
//
// "auto" → "auto"
// "none" → "none"
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try string first ("auto", "none", etc.) — pass through as-is.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal(s)
}
// Object form: {"name":"X"}
var obj struct {
Name string `json:"name"`
}
if err := json.Unmarshal(raw, &obj); err != nil {
return nil, err
}
return json.Marshal(map[string]any{
"type": "function",
"function": map[string]string{"name": obj.Name},
})
}

View File

@@ -0,0 +1,368 @@
package apicompat
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
// ---------------------------------------------------------------------------
// ResponsesToChatCompletions converts a Responses API response into a Chat
// Completions response. Text output items are concatenated into
// choices[0].message.content; function_call items become tool_calls.
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
id := resp.ID
if id == "" {
id = generateChatCmplID()
}
out := &ChatCompletionsResponse{
ID: id,
Object: "chat.completion",
Created: time.Now().Unix(),
Model: model,
}
var contentText string
var toolCalls []ChatToolCall
for _, item := range resp.Output {
switch item.Type {
case "message":
for _, part := range item.Content {
if part.Type == "output_text" && part.Text != "" {
contentText += part.Text
}
}
case "function_call":
toolCalls = append(toolCalls, ChatToolCall{
ID: item.CallID,
Type: "function",
Function: ChatFunctionCall{
Name: item.Name,
Arguments: item.Arguments,
},
})
case "reasoning":
for _, s := range item.Summary {
if s.Type == "summary_text" && s.Text != "" {
contentText += s.Text
}
}
case "web_search_call":
// silently consumed — results already incorporated into text output
}
}
msg := ChatMessage{Role: "assistant"}
if len(toolCalls) > 0 {
msg.ToolCalls = toolCalls
}
if contentText != "" {
raw, _ := json.Marshal(contentText)
msg.Content = raw
}
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
out.Choices = []ChatChoice{{
Index: 0,
Message: msg,
FinishReason: finishReason,
}}
if resp.Usage != nil {
usage := &ChatUsage{
PromptTokens: resp.Usage.InputTokens,
CompletionTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
}
}
out.Usage = usage
}
return out
}
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
switch status {
case "incomplete":
if details != nil && details.Reason == "max_output_tokens" {
return "length"
}
return "stop"
case "completed":
if len(toolCalls) > 0 {
return "tool_calls"
}
return "stop"
default:
return "stop"
}
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
// ---------------------------------------------------------------------------
// ResponsesEventToChatState tracks state for converting a sequence of Responses
// SSE events into Chat Completions SSE chunks.
type ResponsesEventToChatState struct {
ID string
Model string
Created int64
SentRole bool
SawToolCall bool
SawText bool
Finalized bool // true after finish chunk has been emitted
NextToolCallIndex int // next sequential tool_call index to assign
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
IncludeUsage bool
Usage *ChatUsage
}
// NewResponsesEventToChatState returns an initialised stream state.
func NewResponsesEventToChatState() *ResponsesEventToChatState {
return &ResponsesEventToChatState{
ID: generateChatCmplID(),
Created: time.Now().Unix(),
OutputIndexToToolIndex: make(map[int]int),
}
}
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
// or more Chat Completions chunks, updating state as it goes.
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
switch evt.Type {
case "response.created":
return resToChatHandleCreated(evt, state)
case "response.output_text.delta":
return resToChatHandleTextDelta(evt, state)
case "response.output_item.added":
return resToChatHandleOutputItemAdded(evt, state)
case "response.function_call_arguments.delta":
return resToChatHandleFuncArgsDelta(evt, state)
case "response.reasoning_summary_text.delta":
return resToChatHandleReasoningDelta(evt, state)
case "response.completed", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state)
default:
return nil
}
}
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
// stream ended without a proper completion event (e.g. upstream disconnect).
// It is idempotent: if a completion event already emitted the finish chunk,
// this returns nil.
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
if state.Finalized {
return nil
}
state.Finalized = true
finishReason := "stop"
if state.SawToolCall {
finishReason = "tool_calls"
}
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
if state.IncludeUsage && state.Usage != nil {
chunks = append(chunks, ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{},
Usage: state.Usage,
})
}
return chunks
}
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
data, err := json.Marshal(chunk)
if err != nil {
return "", err
}
return fmt.Sprintf("data: %s\n\n", data), nil
}
// --- internal handlers ---
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Response != nil {
if evt.Response.ID != "" {
state.ID = evt.Response.ID
}
if state.Model == "" && evt.Response.Model != "" {
state.Model = evt.Response.Model
}
}
// Emit the role chunk.
if state.SentRole {
return nil
}
state.SentRole = true
role := "assistant"
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
}
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
state.SawText = true
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
}
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Item == nil || evt.Item.Type != "function_call" {
return nil
}
state.SawToolCall = true
idx := state.NextToolCallIndex
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
state.NextToolCallIndex++
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
ToolCalls: []ChatToolCall{{
Index: &idx,
ID: evt.Item.CallID,
Type: "function",
Function: ChatFunctionCall{
Name: evt.Item.Name,
},
}},
})}
}
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
if !ok {
return nil
}
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
ToolCalls: []ChatToolCall{{
Index: &idx,
Function: ChatFunctionCall{
Arguments: evt.Delta,
},
}},
})}
}
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
}
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
state.Finalized = true
finishReason := "stop"
if evt.Response != nil {
if evt.Response.Usage != nil {
u := evt.Response.Usage
usage := &ChatUsage{
PromptTokens: u.InputTokens,
CompletionTokens: u.OutputTokens,
TotalTokens: u.InputTokens + u.OutputTokens,
}
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: u.InputTokensDetails.CachedTokens,
}
}
state.Usage = usage
}
switch evt.Response.Status {
case "incomplete":
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
finishReason = "length"
}
case "completed":
if state.SawToolCall {
finishReason = "tool_calls"
}
}
} else if state.SawToolCall {
finishReason = "tool_calls"
}
var chunks []ChatCompletionsChunk
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
if state.IncludeUsage && state.Usage != nil {
chunks = append(chunks, ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{},
Usage: state.Usage,
})
}
return chunks
}
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
return ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{{
Index: 0,
Delta: delta,
FinishReason: nil,
}},
}
}
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
empty := ""
return ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{{
Index: 0,
Delta: ChatDelta{Content: &empty},
FinishReason: &finishReason,
}},
}
}
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
func generateChatCmplID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "chatcmpl-" + hex.EncodeToString(b)
}

View File

@@ -329,6 +329,148 @@ type ResponsesStreamEvent struct {
SequenceNumber int `json:"sequence_number,omitempty"`
}
// ---------------------------------------------------------------------------
// OpenAI Chat Completions API types
// ---------------------------------------------------------------------------
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
type ChatCompletionsRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
// Legacy function calling (deprecated but still supported)
Functions []ChatFunction `json:"functions,omitempty"`
FunctionCall json.RawMessage `json:"function_call,omitempty"`
}
// ChatStreamOptions configures streaming behavior.
type ChatStreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
// ChatMessage is a single message in the Chat Completions conversation.
type ChatMessage struct {
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
// Legacy function calling
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
}
// ChatContentPart is a typed content part in a multi-modal message.
type ChatContentPart struct {
Type string `json:"type"` // "text" | "image_url"
Text string `json:"text,omitempty"`
ImageURL *ChatImageURL `json:"image_url,omitempty"`
}
// ChatImageURL contains the URL for an image content part.
type ChatImageURL struct {
URL string `json:"url"`
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
}
// ChatTool describes a tool available to the model.
type ChatTool struct {
Type string `json:"type"` // "function"
Function *ChatFunction `json:"function,omitempty"`
}
// ChatFunction describes a function tool definition.
type ChatFunction struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
// ChatToolCall represents a tool call made by the assistant.
// Index is only populated in streaming chunks (omitted in non-streaming responses).
type ChatToolCall struct {
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"` // "function"
Function ChatFunctionCall `json:"function"`
}
// ChatFunctionCall contains the function name and arguments.
type ChatFunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
type ChatCompletionsResponse struct {
ID string `json:"id"`
Object string `json:"object"` // "chat.completion"
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChoice `json:"choices"`
Usage *ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ChatChoice is a single completion choice.
type ChatChoice struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
}
// ChatUsage holds token counts in Chat Completions format.
type ChatUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
}
// ChatTokenDetails provides a breakdown of token usage.
type ChatTokenDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`
}
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
type ChatCompletionsChunk struct {
ID string `json:"id"`
Object string `json:"object"` // "chat.completion.chunk"
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChunkChoice `json:"choices"`
Usage *ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ChatChunkChoice is a single choice in a streaming chunk.
type ChatChunkChoice struct {
Index int `json:"index"`
Delta ChatDelta `json:"delta"`
FinishReason *string `json:"finish_reason"` // pointer: null when not final
}
// ChatDelta carries incremental content in a streaming chunk.
type ChatDelta struct {
Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
}
// ---------------------------------------------------------------------------
// Shared constants
// ---------------------------------------------------------------------------

View File

@@ -16,7 +16,7 @@ const (
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
// 这些 token 是客户端特有的,不应透传给上游 API。
var DroppedBetas = []string{BetaFastMode}
var DroppedBetas = []string{}
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming

View File

@@ -18,10 +18,12 @@ func DefaultModels() []Model {
return []Model{
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
}
}

View File

@@ -0,0 +1,28 @@
package gemini
import "testing"
func TestDefaultModels_ContainsImageModels(t *testing.T) {
t.Parallel()
models := DefaultModels()
byName := make(map[string]Model, len(models))
for _, model := range models {
byName[model.Name] = model
}
required := []string{
"models/gemini-2.5-flash-image",
"models/gemini-3.1-flash-image",
}
for _, name := range required {
model, ok := byName[name]
if !ok {
t.Fatalf("expected fallback model %q to exist", name)
}
if len(model.SupportedGenerationMethods) == 0 {
t.Fatalf("expected fallback model %q to advertise generation methods", name)
}
}
}

View File

@@ -13,10 +13,12 @@ type Model struct {
var DefaultModels = []Model{
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.

View File

@@ -0,0 +1,23 @@
package geminicli
import "testing"
func TestDefaultModels_ContainsImageModels(t *testing.T) {
t.Parallel()
byID := make(map[string]Model, len(DefaultModels))
for _, model := range DefaultModels {
byID[model.ID] = model
}
required := []string{
"gemini-2.5-flash-image",
"gemini-3.1-flash-image",
}
for _, id := range required {
if _, ok := byID[id]; !ok {
t.Fatalf("expected curated Gemini model %q to exist", id)
}
}
}

View File

@@ -268,6 +268,7 @@ type IDTokenClaims struct {
type OpenAIAuthClaims struct {
ChatGPTAccountID string `json:"chatgpt_account_id"`
ChatGPTUserID string `json:"chatgpt_user_id"`
ChatGPTPlanType string `json:"chatgpt_plan_type"`
UserID string `json:"user_id"`
Organizations []OrganizationClaim `json:"organizations"`
}
@@ -325,12 +326,9 @@ func (r *RefreshTokenRequest) ToFormData() string {
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims.
// 注意:当前仅解码 payload 并校验 exp未验证 JWT 签名。
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
//
// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
// DecodeIDToken decodes the ID Token JWT payload without validating expiration.
// Use this for best-effort extraction (e.g., during data import) where the token may be expired.
func DecodeIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
@@ -360,6 +358,20 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
return &claims, nil
}
// ParseIDToken parses the ID Token JWT and extracts claims.
// 注意:当前仅解码 payload 并校验 exp未验证 JWT 签名。
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
//
// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
claims, err := DecodeIDToken(idToken)
if err != nil {
return nil, err
}
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
const clockSkewTolerance = 120 // 秒
now := time.Now().Unix()
@@ -367,7 +379,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
}
return &claims, nil
return claims, nil
}
// UserInfo represents user information extracted from ID Token claims.
@@ -375,6 +387,7 @@ type UserInfo struct {
Email string
ChatGPTAccountID string
ChatGPTUserID string
PlanType string
UserID string
OrganizationID string
Organizations []OrganizationClaim
@@ -389,6 +402,7 @@ func (c *IDTokenClaims) GetUserInfo() *UserInfo {
if c.OpenAIAuth != nil {
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
info.PlanType = c.OpenAIAuth.ChatGPTPlanType
info.UserID = c.OpenAIAuth.UserID
info.Organizations = c.OpenAIAuth.Organizations

View File

@@ -16,6 +16,7 @@ import (
"encoding/json"
"errors"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -50,6 +51,18 @@ type accountRepository struct {
schedulerCache service.SchedulerCache
}
var schedulerNeutralExtraKeyPrefixes = []string{
"codex_primary_",
"codex_secondary_",
"codex_5h_",
"codex_7d_",
}
var schedulerNeutralExtraKeys = map[string]struct{}{
"codex_usage_updated_at": {},
"session_window_utilization": {},
}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
@@ -1185,12 +1198,48 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
}
} else {
// 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照,
// 让 sticky session / GetAccount 命中缓存时也能读到最新数据,
// 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。
r.syncSchedulerAccountSnapshot(ctx, id)
}
return nil
}
func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool {
if len(updates) == 0 {
return false
}
for key := range updates {
if isSchedulerNeutralExtraKey(key) {
continue
}
return true
}
return false
}
func isSchedulerNeutralExtraKey(key string) bool {
key = strings.TrimSpace(key)
if key == "" {
return false
}
if _, ok := schedulerNeutralExtraKeys[key]; ok {
return true
}
for _, prefix := range schedulerNeutralExtraKeyPrefixes {
if strings.HasPrefix(key, prefix) {
return true
}
}
return false
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil

View File

@@ -23,6 +23,7 @@ type AccountRepoSuite struct {
type schedulerCacheRecorder struct {
setAccounts []*service.Account
accounts map[int64]*service.Account
}
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
@@ -34,11 +35,20 @@ func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service
}
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
return nil, nil
if s.accounts == nil {
return nil, nil
}
return s.accounts[accountID], nil
}
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
s.setAccounts = append(s.setAccounts, account)
if s.accounts == nil {
s.accounts = make(map[int64]*service.Account)
}
if account != nil {
s.accounts[account.ID] = account
}
return nil
}
@@ -623,6 +633,96 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
s.Require().Equal("val", got.Extra["key"])
}
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-neutral",
Platform: service.PlatformOpenAI,
Extra: map[string]any{"codex_usage_updated_at": "old"},
})
cacheRecorder := &schedulerCacheRecorder{
accounts: map[int64]*service.Account{
account.ID: {
ID: account.ID,
Platform: account.Platform,
Status: service.StatusDisabled,
Extra: map[string]any{
"codex_usage_updated_at": "old",
},
},
},
}
s.repo.schedulerCache = cacheRecorder
updates := map[string]any{
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
"codex_5h_used_percent": 88.5,
"session_window_utilization": 0.42,
}
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
var outboxCount int
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
s.Require().Zero(outboxCount)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().NotNil(cacheRecorder.accounts[account.ID])
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
}
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-codex-exhausted",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Extra: map[string]any{},
})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
s.Require().NoError(err)
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
"codex_7d_reset_after_seconds": 86400,
}))
var count int
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
s.Require().NoError(err)
s.Require().Equal(0, count)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
}
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra-mixed",
Platform: service.PlatformAntigravity,
Extra: map[string]any{},
})
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
s.Require().NoError(err)
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
"mixed_scheduling": true,
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
}))
var count int
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
s.Require().NoError(err)
s.Require().Equal(1, count)
}
// --- GetByCRSAccountID ---
func (s *AccountRepoSuite) TestGetByCRSAccountID() {

View File

@@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo
return updated.QuotaUsed, nil
}
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
// as quota_exhausted, and returns the latest quota state in one round trip.
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
query := `
UPDATE api_keys
SET
quota_used = quota_used + $1,
status = CASE
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
ELSE status
END,
updated_at = NOW()
WHERE id = $3 AND deleted_at IS NULL
RETURNING quota_used, quota, key, status
`
state := &service.APIKeyQuotaUsageState{}
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
if err == sql.ErrNoRows {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return state, nil
}
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).

View File

@@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
user := s.mustCreateUser("quota-state@test.com")
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
key.Quota = 3
key.QuotaUsed = 1
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
s.Require().NotNil(state)
s.Require().Equal(3.5, state.QuotaUsed)
s.Require().Equal(3.0, state.Quota)
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
s.Require().Equal(key.Key, state.Key)
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(3.5, got.QuotaUsed)
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient非事务隔离数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {

View File

@@ -147,17 +147,47 @@ var (
return 1
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
// KEYS[1] = 有序集合键
// ARGV[1] = TTL(秒)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
if redis.call('ZCARD', key) == 0 then
redis.call('DEL', key)
else
redis.call('EXPIRE', key, ttl)
end
return 1
`)
// startupCleanupScript 清理非当前进程前缀的槽位成员。
// KEYS 是有序集合键列表ARGV[1] 是当前进程前缀ARGV[2] 是槽位 TTL。
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key否则刷新 EXPIRE。
startupCleanupScript = redis.NewScript(`
local activePrefix = ARGV[1]
local slotTTL = tonumber(ARGV[2])
local removed = 0
for i = 1, #KEYS do
local key = KEYS[i]
local members = redis.call('ZRANGE', key, 0, -1)
for _, member in ipairs(members) do
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
removed = removed + redis.call('ZREM', key, member)
end
end
if redis.call('ZCARD', key) == 0 then
redis.call('DEL', key)
else
redis.call('EXPIRE', key, slotTTL)
end
end
return removed
`)
)
type concurrencyCache struct {
@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
if activeRequestPrefix == "" {
return nil
}
// 1. 清理有序集合中非当前进程前缀的成员
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
for _, pattern := range slotPatterns {
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
return err
}
}
// 2. 删除所有等待队列计数器(重启后计数器失效)
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
for _, pattern := range waitPatterns {
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
return err
}
}
return nil
}
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
const scanCount = 200
var cursor uint64
for {
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
if err != nil {
return fmt.Errorf("scan %s: %w", pattern, err)
}
if len(keys) > 0 {
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
if err != nil {
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return nil
}
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
const scanCount = 200
var cursor uint64
for {
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
if err != nil {
return fmt.Errorf("scan %s: %w", pattern, err)
}
if len(keys) > 0 {
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
return fmt.Errorf("del %s: %w", pattern, err)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return nil
}

View File

@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
cache service.ConcurrencyCache
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
accountID := int64(901)
userID := int64(902)
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
now := time.Now().Unix()
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
redis.Z{Score: float64(now), Member: "oldproc-1"},
redis.Z{Score: float64(now), Member: "keep-1"},
).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
redis.Z{Score: float64(now), Member: "oldproc-2"},
redis.Z{Score: float64(now), Member: "keep-2"},
).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"keep-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
accountID := int64(901)
userID := int64(902)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
now := float64(time.Now().Unix())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
redis.Z{Score: now, Member: "oldproc-1"},
redis.Z{Score: now, Member: "activeproc-1"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
redis.Z{Score: now, Member: "oldproc-2"},
redis.Z{Score: now, Member: "activeproc-2"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
}
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
accountID := int64(903)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
require.NoError(s.T(), err)
require.EqualValues(s.T(), 0, exists)
}

View File

@@ -16,19 +16,7 @@ type opsRepository struct {
db *sql.DB
}
func NewOpsRepository(db *sql.DB) service.OpsRepository {
return &opsRepository{db: db}
}
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if input == nil {
return 0, fmt.Errorf("nil input")
}
q := `
const insertOpsErrorLogSQL = `
INSERT INTO ops_error_logs (
request_id,
client_request_id,
@@ -70,12 +58,77 @@ INSERT INTO ops_error_logs (
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
) RETURNING id`
)`
func NewOpsRepository(db *sql.DB) service.OpsRepository {
return &opsRepository{db: db}
}
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if input == nil {
return 0, fmt.Errorf("nil input")
}
var id int64
err := r.db.QueryRowContext(
ctx,
q,
insertOpsErrorLogSQL+" RETURNING id",
opsInsertErrorLogArgs(input)...,
).Scan(&id)
if err != nil {
return 0, err
}
return id, nil
}
func (r *opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if len(inputs) == 0 {
return 0, nil
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL)
if err != nil {
return 0, err
}
defer func() {
_ = stmt.Close()
}()
var inserted int64
for _, input := range inputs {
if input == nil {
continue
}
if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil {
return inserted, err
}
inserted++
}
if err = tx.Commit(); err != nil {
return inserted, err
}
return inserted, nil
}
func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
return []any{
opsNullString(input.RequestID),
opsNullString(input.ClientRequestID),
opsNullInt64(input.UserID),
@@ -114,11 +167,7 @@ INSERT INTO ops_error_logs (
input.IsRetryable,
input.RetryCount,
input.CreatedAt,
).Scan(&id)
if err != nil {
return 0, err
}
return id, nil
}
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {

View File

@@ -0,0 +1,79 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
repo := NewOpsRepository(integrationDB).(*opsRepository)
now := time.Now().UTC()
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
{
RequestID: "batch-ops-1",
ErrorPhase: "upstream",
ErrorType: "upstream_error",
Severity: "error",
StatusCode: 429,
ErrorMessage: "rate limited",
CreatedAt: now,
},
{
RequestID: "batch-ops-2",
ErrorPhase: "internal",
ErrorType: "api_error",
Severity: "error",
StatusCode: 500,
ErrorMessage: "internal error",
CreatedAt: now.Add(time.Millisecond),
},
})
require.NoError(t, err)
require.EqualValues(t, 2, inserted)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
require.Equal(t, 2, count)
}
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
accountID := int64(12345)
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
require.Equal(t, 1, count)
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
require.Equal(t, 2, count)
}
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
accountID := int64(67890)
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
require.Equal(t, 2, count)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -12,6 +13,8 @@ type schedulerOutboxRepository struct {
db *sql.DB
}
const schedulerOutboxDedupWindow = time.Second
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
return &schedulerOutboxRepository{db: db}
}
@@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str
}
payloadArg = encoded
}
_, err := exec.ExecContext(ctx, `
query := `
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
VALUES ($1, $2, $3, $4)
`, eventType, accountID, groupID, payloadArg)
`
args := []any{eventType, accountID, groupID, payloadArg}
if schedulerOutboxEventSupportsDedup(eventType) {
query = `
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
SELECT $1, $2, $3, $4
WHERE NOT EXISTS (
SELECT 1
FROM scheduler_outbox
WHERE event_type = $1
AND account_id IS NOT DISTINCT FROM $2
AND group_id IS NOT DISTINCT FROM $3
AND created_at >= NOW() - make_interval(secs => $5)
)
`
args = append(args, schedulerOutboxDedupWindow.Seconds())
}
_, err := exec.ExecContext(ctx, query, args...)
return err
}
func schedulerOutboxEventSupportsDedup(eventType string) bool {
switch eventType {
case service.SchedulerOutboxEventAccountChanged,
service.SchedulerOutboxEventGroupChanged,
service.SchedulerOutboxEventFullRebuild:
return true
default:
return false
}
}

View File

@@ -645,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)

View File

@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,

View File

@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)

View File

@@ -264,6 +264,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError)
accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh)
// Antigravity 默认模型映射
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
@@ -396,6 +398,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 请求整流器配置
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
@@ -451,6 +456,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}

View File

@@ -61,6 +61,12 @@ func RegisterAuthRoutes(
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
}
// 公开设置(无需认证)

View File

@@ -71,15 +71,8 @@ func RegisterGatewayRoutes(
gateway.POST("/responses", h.OpenAIGateway.Responses)
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// 明确阻止旧协议入口OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
},
})
})
// OpenAI Chat Completions API
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
}
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
@@ -100,6 +93,8 @@ func RegisterGatewayRoutes(
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API不带v1前缀的别名
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)

View File

@@ -45,16 +45,23 @@ const (
// TestEvent represents a SSE event for account testing
type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
ImageURL string `json:"image_url,omitempty"`
MimeType string `json:"mime_type,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
const (
defaultGeminiTextTestPrompt = "hi"
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
)
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo AccountRepository
@@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
ctx := c.Request.Context()
// Get account
@@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
if account.IsGemini() {
return s.testGeminiAccountConnection(c, account, modelID)
return s.testGeminiAccountConnection(c, account, modelID, prompt)
}
if account.Platform == PlatformAntigravity {
return s.routeAntigravityTest(c, account, modelID)
return s.routeAntigravityTest(c, account, modelID, prompt)
}
if account.Platform == PlatformSora {
@@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
// Determine the model to use
@@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
c.Writer.Flush()
// Create test payload (Gemini format)
payload := createGeminiTestPayload()
payload := createGeminiTestPayload(testModelID, prompt)
// Build request based on account type
var req *http.Request
@@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string {
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
if account.Type == AccountTypeAPIKey {
if strings.HasPrefix(modelID, "gemini-") {
return s.testGeminiAccountConnection(c, account, modelID)
return s.testGeminiAccountConnection(c, account, modelID, prompt)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
@@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
return req, nil
}
// createGeminiTestPayload creates a minimal test payload for Gemini API
func createGeminiTestPayload() []byte {
// createGeminiTestPayload creates a minimal test payload for Gemini API.
// Image models use the image-generation path so the frontend can preview the returned image.
func createGeminiTestPayload(modelID string, prompt string) []byte {
if isImageGenerationModel(modelID) {
imagePrompt := strings.TrimSpace(prompt)
if imagePrompt == "" {
imagePrompt = defaultGeminiImageTestPrompt
}
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": imagePrompt},
},
},
},
"generationConfig": map[string]any{
"responseModalities": []string{"TEXT", "IMAGE"},
"imageConfig": map[string]any{
"aspectRatio": "1:1",
},
},
}
bytes, _ := json.Marshal(payload)
return bytes
}
textPrompt := strings.TrimSpace(prompt)
if textPrompt == "" {
textPrompt = defaultGeminiTextTestPrompt
}
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
{"text": "hi"},
{"text": textPrompt},
},
},
},
@@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
if text, ok := partMap["text"].(string); ok && text != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
if inlineData, ok := partMap["inlineData"].(map[string]any); ok {
mimeType, _ := inlineData["mimeType"].(string)
data, _ := inlineData["data"].(string)
if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" {
s.sendEvent(c, TestEvent{
Type: "image",
ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data),
MimeType: mimeType,
})
}
}
}
}
}
@@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
finishedAt := time.Now()
body := w.Body.String()

View File

@@ -0,0 +1,59 @@
//go:build unit
package service
import (
"encoding/json"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
t.Parallel()
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
var parsed struct {
Contents []struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"contents"`
GenerationConfig struct {
ResponseModalities []string `json:"responseModalities"`
ImageConfig struct {
AspectRatio string `json:"aspectRatio"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
require.NoError(t, json.Unmarshal(payload, &parsed))
require.Len(t, parsed.Contents, 1)
require.Len(t, parsed.Contents[0].Parts, 1)
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
}
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
err := svc.processGeminiStream(ctx, stream)
require.NoError(t, err)
body := recorder.Body.String()
require.Contains(t, body, "\"type\":\"content\"")
require.Contains(t, body, "\"text\":\"ok\"")
require.Contains(t, body, "\"type\":\"image\"")
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
require.Contains(t, body, "\"mime_type\":\"image/png\"")
}

View File

@@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
}
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
mergeAccountExtra(account, updates)
if resetAt != nil {
account.RateLimitResetAt = resetAt
}
if usage.UpdatedAt == nil {
usage.UpdatedAt = &now
}
@@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
return true
}
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
if account == nil || !account.IsOAuth() {
return nil, nil
return nil, nil, nil
}
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
return nil, nil, fmt.Errorf("no access token available")
}
modelID := openaipkg.DefaultTestModel
payload := createOpenAITestPayload(modelID, true)
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
}
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("create openai probe request: %w", err)
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
}
req.Host = "chatgpt.com"
req.Header.Set("Content-Type", "application/json")
@@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
ResponseHeaderTimeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("build openai probe client: %w", err)
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
updates, err := extractOpenAICodexProbeUpdates(resp)
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
if err != nil {
return nil, err
return nil, nil, err
}
if len(updates) > 0 {
go func(accountID int64, updates map[string]any) {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if len(updates) > 0 || resetAt != nil {
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
return updates, resetAt, nil
}
return nil, nil, nil
}
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
if s == nil || s.accountRepo == nil || accountID <= 0 {
return
}
if len(updates) == 0 && resetAt == nil {
return
}
go func() {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}(account.ID, updates)
return updates, nil
}
if resetAt != nil {
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
}
}()
}
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
if resp == nil {
return nil, nil, nil
}
return nil, nil
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
baseTime := time.Now()
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
if len(updates) > 0 {
return updates, resetAt, nil
}
return nil, resetAt, nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
}
return nil, nil, nil
}
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
if resp == nil {
return nil, nil
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
if len(updates) > 0 {
return updates, nil
}
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
}
return nil, nil
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
return updates, err
}
func mergeAccountExtra(account *Account, updates map[string]any) {

View File

@@ -1,11 +1,36 @@
package service
import (
"context"
"net/http"
"testing"
"time"
)
type accountUsageCodexProbeRepo struct {
stubOpenAIAccountRepo
updateExtraCh chan map[string]any
rateLimitCh chan time.Time
}
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
if r.updateExtraCh != nil {
copied := make(map[string]any, len(updates))
for k, v := range updates {
copied[k] = v
}
r.updateExtraCh <- copied
}
return nil
}
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
if r.rateLimitCh != nil {
r.rateLimitCh <- resetAt
}
return nil
}
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
t.Parallel()
@@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
}
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "604800")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
headers.Set("x-codex-secondary-window-minutes", "300")
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
if err != nil {
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
}
if len(updates) == 0 {
t.Fatal("expected codex probe updates from 429 headers")
}
if resetAt == nil {
t.Fatal("expected resetAt from exhausted codex headers")
}
}
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
t.Parallel()
repo := &accountUsageCodexProbeRepo{
updateExtraCh: make(chan map[string]any, 1),
rateLimitCh: make(chan time.Time, 1),
}
svc := &AccountUsageService{accountRepo: repo}
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
}, &resetAt)
select {
case updates := <-repo.updateExtraCh:
if got := updates["codex_7d_used_percent"]; got != 100.0 {
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe extra persistence timed out")
}
select {
case got := <-repo.rateLimitCh:
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe rate limit persistence timed out")
}
}

View File

@@ -432,6 +432,7 @@ type adminServiceImpl struct {
entClient *dbent.Client // 用于开启数据库事务
settingService *SettingService
defaultSubAssigner DefaultSubscriptionAssigner
userSubRepo UserSubscriptionRepository
}
type userGroupRateBatchReader interface {
@@ -459,6 +460,7 @@ func NewAdminService(
entClient *dbent.Client,
settingService *SettingService,
defaultSubAssigner DefaultSubscriptionAssigner,
userSubRepo UserSubscriptionRepository,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -476,6 +478,7 @@ func NewAdminService(
entClient: entClient,
settingService: settingService,
defaultSubAssigner: defaultSubAssigner,
userSubRepo: userSubRepo,
}
}
@@ -1277,9 +1280,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
if group.Status != StatusActive {
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
}
// 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
// 订阅类型分组:用户须持有该分组的有效订阅才可绑定
if group.IsSubscriptionType() {
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
if s.userSubRepo == nil {
return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured")
}
if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil {
if errors.Is(err, ErrSubscriptionNotFound) {
return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group")
}
return nil, err
}
}
gid := *groupID
@@ -1287,7 +1298,7 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
apiKey.Group = group
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
if group.IsExclusive {
if group.IsExclusive && !group.IsSubscriptionType() {
opCtx := ctx
var tx *dbent.Tx
if s.entClient == nil {

View File

@@ -32,28 +32,44 @@ func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context,
return s.addGroupErr
}
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct {
@@ -194,6 +210,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
panic("unexpected")
}
type userSubRepoStubForGroupUpdate struct {
userSubRepoNoop
getActiveSub *UserSubscription
getActiveErr error
called bool
calledUserID int64
calledGroupID int64
}
func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
s.called = true
s.calledUserID = userID
s.calledGroupID = groupID
if s.getActiveErr != nil {
return nil, s.getActiveErr
}
if s.getActiveSub == nil {
return nil, ErrSubscriptionNotFound
}
clone := *s.getActiveSub
return &clone, nil
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
@@ -386,14 +425,49 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
// 无有效订阅时应拒绝绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
require.True(t, userSubRepo.called)
require.Equal(t, int64(42), userSubRepo.calledUserID)
require.Equal(t, int64(10), userSubRepo.calledGroupID)
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 订阅类型分组应被阻止绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
userSubRepo := &userSubRepoStubForGroupUpdate{
getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.True(t, userSubRepo.called)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.False(t, userRepo.addGroupCalled)
}

View File

@@ -2164,6 +2164,112 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
}
// Gemini 原生请求中的 thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回
// "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。
signatureCheckBody := respBody
if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 {
signatureCheckBody = unwrapped
}
if resp.StatusCode == http.StatusBadRequest &&
s.settingService != nil &&
s.settingService.IsSignatureRectifierEnabled(ctx) &&
isSignatureRelatedError(signatureCheckBody) &&
bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) {
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody)))
upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
if wrapErr == nil {
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: upstreamAction,
body: retryWrappedBody,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession,
groupID: 0,
sessionHash: "",
})
if retryErr == nil {
retryResp := retryResult.resp
if retryResp.StatusCode < 400 {
resp = retryResp
} else {
retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
retryOpsBody := retryRespBody
if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 {
retryOpsBody = retryUnwrapped
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: "signature_retry",
Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))),
Detail: s.getUpstreamErrorDetail(retryOpsBody),
})
respBody = retryRespBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
}
contentType = resp.Header.Get("Content-Type")
}
} else {
if switchErr, ok := IsAntigravityAccountSwitchError(retryErr); ok {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: http.StatusServiceUnavailable,
Kind: "failover",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
return nil, &UpstreamFailoverError{
StatusCode: http.StatusServiceUnavailable,
ForceCacheBilling: switchErr.IsStickySession,
}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr)
}
} else {
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap failed: %v", account.ID, wrapErr)
}
}
// fallback 成功:继续按正常响应处理
if resp.StatusCode < 400 {
goto handleSuccess

View File

@@ -134,6 +134,47 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
return s.resp, s.err
}
type queuedHTTPUpstreamStub struct {
responses []*http.Response
errors []error
requestBodies [][]byte
callCount int
onCall func(*http.Request, *queuedHTTPUpstreamStub)
}
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
if req != nil && req.Body != nil {
body, _ := io.ReadAll(req.Body)
s.requestBodies = append(s.requestBodies, body)
req.Body = io.NopCloser(bytes.NewReader(body))
} else {
s.requestBodies = append(s.requestBodies, nil)
}
idx := s.callCount
s.callCount++
if s.onCall != nil {
s.onCall(req, s)
}
var resp *http.Response
if idx < len(s.responses) {
resp = s.responses[idx]
}
var err error
if idx < len(s.errors) {
err = s.errors[idx]
}
if resp == nil && err == nil {
return nil, errors.New("unexpected upstream call")
}
return resp, err
}
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, concurrency)
}
type antigravitySettingRepoStub struct{}
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
@@ -556,6 +597,177 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
require.Equal(t, mappedModel, result.Model)
}
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
c.Request = req
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req-sig-1"},
},
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req-sig-2"},
},
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
},
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
const originalModel = "gemini-3.1-pro-preview"
const mappedModel = "gemini-3.1-pro-high"
account := &Account{
ID: 7,
Name: "acc-gemini-signature",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
originalModel: mappedModel,
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
firstReq := string(upstream.requestBodies[0])
secondReq := string(upstream.requestBodies[1])
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.NotEmpty(t, events)
require.Equal(t, "signature_error", events[0].Kind)
}
func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
c.Request = req
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
const originalModel = "gemini-3.1-pro-preview"
const mappedModel = "gemini-3.1-pro-high"
account := &Account{
ID: 8,
Name: "acc-gemini-signature-failover",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
originalModel: mappedModel,
},
},
}
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req-sig-failover-1"},
},
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
},
},
onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) {
if stub.callCount != 1 {
return
}
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account.Extra = map[string]any{
modelRateLimitsKey: map[string]any{
mappedModel: map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
}
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true)
require.Nil(t, result)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling)
require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request")
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 2)
require.Equal(t, "signature_error", events[0].Kind)
require.Equal(t, "failover", events[1].Kind)
}
// TestStreamUpstreamResponse_UsageAndFirstToken
// 验证usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {

View File

@@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"strconv"
"strings"
"sync"
"time"
@@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
return d.Usage7d
}
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
// It is intentionally small so repositories can return it from a single SQL statement.
type APIKeyQuotaUsageState struct {
QuotaUsed float64
Quota float64
Key string
Status string
}
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
@@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
return nil
}
type quotaStateReader interface {
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
}
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
if err != nil {
return fmt.Errorf("increment quota used: %w", err)
}
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
s.InvalidateAuthCacheByKey(ctx, state.Key)
}
return nil
}
// Use repository to atomically increment quota_used
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
if err != nil {

View File

@@ -0,0 +1,170 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type quotaStateRepoStub struct {
quotaBaseAPIKeyRepoStub
stateCalls int
state *APIKeyQuotaUsageState
stateErr error
}
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
s.stateCalls++
if s.stateErr != nil {
return nil, s.stateErr
}
if s.state == nil {
return nil, nil
}
out := *s.state
return &out, nil
}
type quotaStateCacheStub struct {
deleteAuthKeys []string
}
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
return 0, nil
}
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
return nil
}
type quotaBaseAPIKeyRepoStub struct {
getByIDCalls int
}
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
panic("unexpected Create call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
s.getByIDCalls++
return nil, nil
}
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
panic("unexpected Update call")
}
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected UpdateLastUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
repo := &quotaStateRepoStub{
state: &APIKeyQuotaUsageState{
QuotaUsed: 12,
Quota: 10,
Key: "sk-test-quota",
Status: StatusAPIKeyQuotaExhausted,
},
}
cache := &quotaStateCacheStub{}
svc := &APIKeyService{
apiKeyRepo: repo,
cache: cache,
}
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
require.NoError(t, err)
require.Equal(t, 1, repo.stateCalls)
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
}

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -21,24 +22,25 @@ import (
)
var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration")
)
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
@@ -58,6 +60,7 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
entClient *dbent.Client
userRepo UserRepository
redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
@@ -76,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
@@ -88,6 +92,7 @@ func NewAuthService(
defaultSubAssigner DefaultSubscriptionAssigner,
) *AuthService {
return &AuthService{
entClient: entClient,
userRepo: userRepo,
redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
@@ -523,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return token, user, nil
}
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
@@ -552,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrRegDisabled
}
// 检查是否需要邀请码
var invitationRedeemCode *RedeemCode
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
if invitationCode == "" {
return nil, nil, ErrOAuthInvitationRequired
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
return nil, nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
}
randomPassword, err := randomHexString(32)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err)
@@ -579,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
if s.entClient != nil && invitationRedeemCode != nil {
tx, err := s.entClient.Tx(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
return nil, nil, ErrServiceUnavailable
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := s.userRepo.Create(txCtx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if err := tx.Commit(); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
return nil, nil, ErrServiceUnavailable
}
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
}
}
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -618,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
type pendingOAuthClaims struct {
Email string `json:"email"`
Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
}
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
// while waiting for the user to supply an invitation code.
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
now := time.Now()
claims := &pendingOAuthClaims{
Email: email,
Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.cfg.JWT.Secret))
}
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
// Returns ErrInvalidToken when the token is invalid or expired.
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
if len(tokenStr) > maxTokenLength {
return "", "", ErrInvalidToken
}
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if parseErr != nil {
return "", "", ErrInvalidToken
}
claims, ok := token.Claims.(*pendingOAuthClaims)
if !ok || !token.Valid {
return "", "", ErrInvalidToken
}
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
}
return claims.Email, claims.Username, nil
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return

View File

@@ -0,0 +1,146 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func newAuthServiceForPendingOAuthTest() *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-pending-oauth",
ExpireHour: 1,
},
}
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
}
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
require.NotEmpty(t, token)
email, username, err := svc.VerifyPendingOAuthToken(token)
require.NoError(t, err)
require.Equal(t, "user@example.com", email)
require.Equal(t, "alice", username)
}
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
// 签发一个普通 access tokenJWTClaims无 Purpose 字段)
accessToken, err := svc.GenerateToken(&User{
ID: 1,
Email: "user@example.com",
Role: RoleUser,
})
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "some_other_purpose",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT模拟旧 token应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "", // 旧 token 无此字段,反序列化后为零值
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
past := time.Now().Add(-1 * time.Hour)
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(past),
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
other := NewAuthService(nil, nil, nil, nil, &config.Config{
JWT: config.JWTConfig{Secret: "other-secret"},
}, nil, nil, nil, nil, nil, nil)
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
svc := newAuthServiceForPendingOAuthTest()
_, _, err = svc.VerifyPendingOAuthToken(token)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
giant := make([]byte, maxTokenLength+1)
for i := range giant {
giant[i] = 'a'
}
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
require.ErrorIs(t, err, ErrInvalidToken)
}

View File

@@ -130,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
}
return NewAuthService(
nil, // entClient
repo,
nil, // redeemRepo
nil, // refreshTokenCache

View File

@@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
turnstileService := NewTurnstileService(settingService, verifier)
return NewAuthService(
nil, // entClient
&userRepoStub{},
nil, // redeemRepo
nil, // refreshTokenCache

View File

@@ -43,6 +43,9 @@ type ConcurrencyCache interface {
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
// 启动时清理旧进程遗留槽位与等待计数
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
}
var (
@@ -59,13 +62,22 @@ func initRequestIDPrefix() string {
return "r" + strconv.FormatUint(fallback, 36)
}
// generateRequestID generates a unique request ID for concurrency slot tracking.
// Format: {process_random_prefix}-{base36_counter}
func RequestIDPrefix() string {
return requestIDPrefix
}
func generateRequestID() string {
seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
}
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
if s == nil || s.cache == nil {
return nil
}
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
}
const (
// Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20

View File

@@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte
return c.cleanupErr
}
func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error {
return c.cleanupErr
}
type trackingConcurrencyCache struct {
stubConcurrencyCacheForTest
cleanupPrefix string
}
func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error {
c.cleanupPrefix = prefix
return c.cleanupErr
}
func TestCleanupStaleProcessSlots_NilCache(t *testing.T) {
svc := &ConcurrencyService{cache: nil}
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
}
func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) {
cache := &trackingConcurrencyCache{}
svc := NewConcurrencyService(cache)
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix)
}
func TestAcquireAccountSlot_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: true}
svc := NewConcurrencyService(cache)

View File

@@ -182,6 +182,13 @@ const (
// SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget).
SettingKeyRectifierSettings = "rectifier_settings"
// =========================
// Beta Policy Settings
// =========================
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
// =========================
// Sora S3 存储配置
// =========================

View File

@@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) {
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "DroppedBetas removes fast-mode only",
name: "DroppedBetas is empty (filtering moved to configurable beta policy)",
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
tokens: claude.DroppedBetas,
want: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
want: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
},
}
@@ -114,25 +114,23 @@ func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20"
// DroppedBetas is now empty — filtering moved to configurable beta policy.
// Without a policy filter set, nothing gets dropped from the static set.
drop := droppedBetaSet()
got := mergeAnthropicBetaDropping(required, incoming, drop)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,foo-beta", got)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta", got)
require.Contains(t, got, "context-1m-2025-08-07")
require.NotContains(t, got, "fast-mode-2026-02-01")
require.Contains(t, got, "fast-mode-2026-02-01")
}
func TestDroppedBetaSet(t *testing.T) {
// Base set contains DroppedBetas
// Base set contains DroppedBetas (now empty — filtering moved to configurable beta policy)
base := droppedBetaSet()
require.NotContains(t, base, claude.BetaContext1M)
require.Contains(t, base, claude.BetaFastMode)
require.Len(t, base, len(claude.DroppedBetas))
// With extra tokens
extended := droppedBetaSet(claude.BetaClaudeCode)
require.NotContains(t, extended, claude.BetaContext1M)
require.Contains(t, extended, claude.BetaFastMode)
require.Contains(t, extended, claude.BetaClaudeCode)
require.Len(t, extended, len(claude.DroppedBetas)+1)
}

View File

@@ -1986,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
return nil
}
func (m *mockConcurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
return nil
}
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
result := make(map[int64]*UserLoadInfo, len(users))
for _, user := range users {

View File

@@ -3948,6 +3948,20 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
}
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil {
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
if policy.blockErr != nil {
return nil, policy.blockErr
}
filterSet := policy.filterSet
if filterSet == nil {
filterSet = map[string]struct{}{}
}
c.Set(betaPolicyFilterSetKey, filterSet)
}
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
@@ -5133,6 +5147,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
effectiveDropSet := mergeDropSets(policyFilterSet)
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
// 处理 anthropic-beta headerOAuth 账号需要包含 oauth beta
if tokenType == "oauth" {
if mimicClaudeCode {
@@ -5146,17 +5165,22 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet))
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
} else {
// Claude Code 客户端:尽量透传原始 header仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta")
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet))
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
} else {
// API-key accounts: apply beta policy filter to strip controlled tokens
if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
}
@@ -5334,6 +5358,104 @@ func stripBetaTokensWithSet(header string, drop map[string]struct{}) string {
return strings.Join(out, ",")
}
// BetaBlockedError indicates a request was blocked by a beta policy rule.
type BetaBlockedError struct {
Message string
}
func (e *BetaBlockedError) Error() string { return e.Message }
// betaPolicyResult holds the evaluated result of beta policy rules for a single request.
type betaPolicyResult struct {
blockErr *BetaBlockedError // non-nil if a block rule matched
filterSet map[string]struct{} // tokens to filter (may be nil)
}
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
if s.settingService == nil {
return betaPolicyResult{}
}
settings, err := s.settingService.GetBetaPolicySettings(ctx)
if err != nil || settings == nil {
return betaPolicyResult{}
}
isOAuth := account.IsOAuth()
var result betaPolicyResult
for _, rule := range settings.Rules {
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
continue
}
switch rule.Action {
case BetaPolicyActionBlock:
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
msg := rule.ErrorMessage
if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed"
}
result.blockErr = &BetaBlockedError{Message: msg}
}
case BetaPolicyActionFilter:
if result.filterSet == nil {
result.filterSet = make(map[string]struct{})
}
result.filterSet[rule.BetaToken] = struct{}{}
}
}
return result
}
// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens.
// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation).
func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} {
if len(policySet) == 0 && len(extra) == 0 {
return defaultDroppedBetasSet
}
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra))
for t := range defaultDroppedBetasSet {
m[t] = struct{}{}
}
for t := range policySet {
m[t] = struct{}{}
}
for _, t := range extra {
m[t] = struct{}{}
}
return m
}
// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request.
const betaPolicyFilterSetKey = "betaPolicyFilterSet"
// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available.
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
// evaluates on demand (one DB call).
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
if c != nil {
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
if fs, ok := v.(map[string]struct{}); ok {
return fs
}
}
}
return s.evaluateBetaPolicy(ctx, "", account).filterSet
}
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
switch scope {
case BetaPolicyScopeAll:
return true
case BetaPolicyScopeOAuth:
return isOAuth
case BetaPolicyScopeAPIKey:
return !isOAuth
default:
return true // unknown scope → match all (fail-open)
}
}
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
func droppedBetaSet(extra ...string) map[string]struct{} {
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
@@ -5370,10 +5492,7 @@ func buildBetaTokenSet(tokens []string) map[string]struct{} {
return m
}
var (
defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode)
)
var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
@@ -5879,6 +5998,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
intervalCh = intervalTicker.C
}
// 下游 keepalive防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
@@ -6068,6 +6203,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
break
}
flusher.Flush()
lastDataAt = time.Now()
}
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
@@ -6101,6 +6237,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping 事件Anthropic 原生格式,客户端会正确处理,
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing")
continue
}
flusher.Flush()
}
}
@@ -7311,6 +7463,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
applyClaudeOAuthHeaderDefaults(req, false)
}
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
if mimicClaudeCode {
@@ -7318,8 +7473,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
drop := droppedBetaSet()
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
} else {
clientBetaHeader := req.Header.Get("anthropic-beta")
if clientBetaHeader == "" {
@@ -7329,14 +7483,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet))
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet))
}
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
} else {
// API-key accounts: apply beta policy filter to strip controlled tokens
if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" {
req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
}

View File

@@ -0,0 +1,75 @@
package service
import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures(t *testing.T) {
input := []byte(`{
"contents": [
{
"role": "user",
"parts": [{"text": "hello"}]
},
{
"role": "model",
"parts": [
{"text": "thinking", "thought": true, "thoughtSignature": "sig_1"},
{"functionCall": {"name": "toolA", "args": {"k": "v"}}, "thoughtSignature": "sig_2"}
]
}
],
"cachedContent": {
"parts": [{"text": "cached", "thoughtSignature": "sig_3"}]
},
"signature": "keep_me"
}`)
cleaned := CleanGeminiNativeThoughtSignatures(input)
var got map[string]any
require.NoError(t, json.Unmarshal(cleaned, &got))
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_1"`)
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_2"`)
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_3"`)
require.Contains(t, string(cleaned), `"thoughtSignature":"`+antigravity.DummyThoughtSignature+`"`)
require.Contains(t, string(cleaned), `"signature":"keep_me"`)
}
func TestCleanGeminiNativeThoughtSignatures_InvalidJSONReturnsOriginal(t *testing.T) {
input := []byte(`{"contents":[invalid-json]}`)
cleaned := CleanGeminiNativeThoughtSignatures(input)
require.Equal(t, input, cleaned)
}
func TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField(t *testing.T) {
input := map[string]any{
"thoughtSignature": "sig_root",
"signature": "keep_signature",
"nested": []any{
map[string]any{
"thoughtSignature": "sig_nested",
"signature": "keep_nested_signature",
},
},
}
got, ok := replaceThoughtSignaturesRecursive(input).(map[string]any)
require.True(t, ok)
require.Equal(t, antigravity.DummyThoughtSignature, got["thoughtSignature"])
require.Equal(t, "keep_signature", got["signature"])
nested, ok := got["nested"].([]any)
require.True(t, ok)
nestedMap, ok := nested[0].(map[string]any)
require.True(t, ok)
require.Equal(t, antigravity.DummyThoughtSignature, nestedMap["thoughtSignature"])
require.Equal(t, "keep_nested_signature", nestedMap["signature"])
}

View File

@@ -1,6 +1,7 @@
package service
import (
"fmt"
"strings"
)
@@ -146,6 +147,22 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
input = filterCodexInput(input, needsToolContinuation)
reqBody["input"] = input
result.Modified = true
} else if inputStr, ok := reqBody["input"].(string); ok {
// ChatGPT codex endpoint requires input to be a list, not a string.
// Convert string input to the expected message array format.
trimmed := strings.TrimSpace(inputStr)
if trimmed != "" {
reqBody["input"] = []any{
map[string]any{
"type": "message",
"role": "user",
"content": inputStr,
},
}
} else {
reqBody["input"] = []any{}
}
result.Modified = true
}
return result
@@ -210,6 +227,29 @@ func normalizeCodexModel(model string) string {
return "gpt-5.1"
}
func SupportsVerbosity(model string) bool {
if !strings.HasPrefix(model, "gpt-") {
return true
}
var major, minor int
n, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor)
if major > 5 {
return true
}
if major < 5 {
return false
}
// gpt-5
if n == 1 {
return true
}
return minor >= 3
}
func getNormalizedCodexModel(modelID string) string {
if modelID == "" {
return ""

View File

@@ -249,6 +249,50 @@ func TestApplyCodexOAuthTransform_NonCodexCLI_PreservesExistingInstructions(t *t
require.Equal(t, "old instructions", instructions)
}
func TestApplyCodexOAuthTransform_StringInputConvertedToArray(t *testing.T) {
reqBody := map[string]any{"model": "gpt-5.4", "input": "Hello, world!"}
result := applyCodexOAuthTransform(reqBody, false, false)
require.True(t, result.Modified)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
msg, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "message", msg["type"])
require.Equal(t, "user", msg["role"])
require.Equal(t, "Hello, world!", msg["content"])
}
func TestApplyCodexOAuthTransform_EmptyStringInputBecomesEmptyArray(t *testing.T) {
reqBody := map[string]any{"model": "gpt-5.4", "input": ""}
result := applyCodexOAuthTransform(reqBody, false, false)
require.True(t, result.Modified)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 0)
}
func TestApplyCodexOAuthTransform_WhitespaceStringInputBecomesEmptyArray(t *testing.T) {
reqBody := map[string]any{"model": "gpt-5.4", "input": " "}
result := applyCodexOAuthTransform(reqBody, false, false)
require.True(t, result.Modified)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 0)
}
func TestApplyCodexOAuthTransform_StringInputWithToolsField(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
"input": "Run the tests",
"tools": []any{map[string]any{"type": "function", "name": "bash"}},
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
}
func TestIsInstructionsEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -0,0 +1,512 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Chat Completions format. All account types (OAuth and API
// Key) go through the Responses API conversion path since the upstream only
// exposes the /v1/responses endpoint.
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
// 1. Parse Chat Completions request
var chatReq apicompat.ChatCompletionsRequest
if err := json.Unmarshal(body, &chatReq); err != nil {
return nil, fmt.Errorf("parse chat completions request: %w", err)
}
originalModel := chatReq.Model
clientStream := chatReq.Stream
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
// 2. Convert to Responses and forward
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
// 3. Model mapping
mappedModel := account.GetMappedModel(originalModel)
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel
logger.L().Debug("openai chat_completions: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("stream", clientStream),
)
// 4. Marshal Responses request body, then apply OAuth codex transform
responsesBody, err := json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
}
if account.Type == AccountTypeOAuth {
var reqBody map[string]any
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
reqBody["prompt_cache_key"] = promptCacheKey
}
responsesBody, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
}
}
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 8. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account)
}
// 9. Handle normal response
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
} else {
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
}
// Propagate ServiceTier and ReasoningEffort to result for billing
if handleErr == nil && result != nil {
if responsesReq.ServiceTier != "" {
st := responsesReq.ServiceTier
result.ServiceTier = &st
}
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
re := responsesReq.Reasoning.Effort
result.ReasoningEffort = &re
}
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if handleErr == nil && account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
return result, handleErr
}
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}
// handleChatBufferedStreamingResponse reads all Responses SSE events from the
// upstream, finds the terminal event, converts to a Chat Completions JSON
// response, and writes it to the client.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResponse == nil {
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
return nil, fmt.Errorf("upstream stream ended without terminal event")
}
chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, chatResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
includeUsage bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewResponsesEventToChatState()
state.Model = originalModel
state.IncludeUsage = includeUsage
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
processDataLine := func(payload string) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
return false
}
// Extract usage from completion events
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected",
zap.String("request_id", requestID),
)
return true
}
}
if len(chunks) > 0 {
c.Writer.Flush()
}
return false
}
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
}
// Send [DONE] sentinel
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
c.Writer.Flush()
return resultWithUsage(), nil
}
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// Determine keepalive interval
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// No keepalive: fast synchronous path
if keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
}
handleScanErr(scanner.Err())
return finalizeStream()
}
// With keepalive: goroutine + channel + select
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
go func() {
defer close(events)
for scanner.Scan() {
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
return finalizeStream()
}
if ev.err != nil {
handleScanErr(ev.err)
return finalizeStream()
}
lastDataAt = time.Now()
line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
case <-keepaliveTicker.C:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send SSE comment as keepalive
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
return resultWithUsage(), nil
}
c.Writer.Flush()
}
}
}
// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -2,6 +2,7 @@ package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
@@ -140,12 +141,13 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// 8. Handle error response with failover
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
@@ -167,7 +169,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
// Non-failover error: return Anthropic-formatted error to client
return s.handleAnthropicErrorResponse(resp, c, account)
@@ -213,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Record upstream error details for ops logging
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeAnthropicError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
@@ -279,7 +238,11 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
@@ -378,7 +341,11 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult {

View File

@@ -52,6 +52,8 @@ const (
openAIWSRetryJitterRatioDefault = 0.2
openAICompactSessionSeedKey = "openai_compact_session_seed"
codexCLIVersion = "0.104.0"
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
openAICodexSnapshotPersistMinInterval = 30 * time.Second
)
// OpenAI allowed headers whitelist (for non-passthrough).
@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct {
nonRetryableFastFallback atomic.Int64
}
type accountWriteThrottle struct {
minInterval time.Duration
mu sync.Mutex
lastByID map[int64]time.Time
}
func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle {
return &accountWriteThrottle{
minInterval: minInterval,
lastByID: make(map[int64]time.Time),
}
}
func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
if t == nil || id <= 0 || t.minInterval <= 0 {
return true
}
t.mu.Lock()
defer t.mu.Unlock()
if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval {
return false
}
t.lastByID[id] = now
if len(t.lastByID) > 4096 {
cutoff := now.Add(-4 * t.minInterval)
for accountID, writtenAt := range t.lastByID {
if writtenAt.Before(cutoff) {
delete(t.lastByID, accountID)
}
}
}
return true
}
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
@@ -289,6 +331,7 @@ type OpenAIGatewayService struct {
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics
responseHeaderFilter *responseheaders.CompiledHeaderFilter
codexSnapshotThrottle *accountWriteThrottle
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -329,17 +372,25 @@ func NewOpenAIGatewayService(
nil,
"service.openai_gateway",
),
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
svc.logOpenAIWSModeBootstrap()
return svc
}
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
if s != nil && s.codexSnapshotThrottle != nil {
return s.codexSnapshotThrottle
}
return defaultOpenAICodexSnapshotPersistThrottle
}
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
@@ -911,6 +962,36 @@ func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg strin
return false
}
func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool {
if upstreamStatusCode != http.StatusBadRequest {
return false
}
match := func(text string) bool {
lower := strings.ToLower(strings.TrimSpace(text))
if lower == "" {
return false
}
if strings.Contains(lower, "an error occurred while processing your request") {
return true
}
return strings.Contains(lower, "you can retry your request") &&
strings.Contains(lower, "help.openai.com") &&
strings.Contains(lower, "request id")
}
if match(upstreamMsg) {
return true
}
if len(upstreamBody) == 0 {
return false
}
if match(gjson.GetBytes(upstreamBody, "error.message").String()) {
return true
}
return match(string(upstreamBody))
}
// ExtractSessionID extracts the raw session ID from headers or body without hashing.
// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache.
func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string {
@@ -1518,6 +1599,13 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
}
}
func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
if s.shouldFailoverUpstreamError(statusCode) {
return true
}
return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody)
}
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
@@ -1679,6 +1767,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true
markPatchSet("model", normalizedModel)
}
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
// 确保高版本模型向低版本模型映射不报错
if !SupportsVerbosity(normalizedModel) {
if text, ok := reqBody["text"].(map[string]any); ok {
delete(text, "verbosity")
}
}
}
// 规范化 reasoning.effort 参数minimal -> none与上游允许值对齐。
@@ -2016,13 +2112,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Handle error response
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
@@ -2046,7 +2142,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
@@ -2910,6 +3006,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// compatErrorWriter is the signature for format-specific error writers used by
// the compat paths (Chat Completions and Anthropic Messages).
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
// handleCompatErrorResponse is the shared non-failover error handler for the
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
// tracking, secondary failover) but delegates the final error write to the
// format-specific writer function.
func (s *OpenAIGatewayService) handleCompatErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
writeError compatErrorWriter,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
// Check custom error codes — if the account does not handle this status,
// return a generic error without exposing upstream details.
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
}
// Track rate limits and decide whether to trigger secondary failover.
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
}
kind := "http_error"
if shouldDisable {
kind = "failover"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
Message: upstreamMsg,
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: body,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
// Map status code to error type and write response
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// openaiStreamingResult streaming response result
type openaiStreamingResult struct {
usage *OpenAIUsage
@@ -4013,11 +4223,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
if len(updates) == 0 && resetAt == nil {
return
}
shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now)
if !shouldPersistUpdates && resetAt == nil {
return
}
go func() {
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if len(updates) > 0 {
if shouldPersistUpdates {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}
if resetAt != nil {

View File

@@ -211,6 +211,26 @@ func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T)
require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required已记录请求详情用于排查"))
}
func TestIsOpenAITransientProcessingError(t *testing.T) {
require.True(t, isOpenAITransientProcessingError(
http.StatusBadRequest,
"An error occurred while processing your request.",
nil,
))
require.True(t, isOpenAITransientProcessingError(
http.StatusBadRequest,
"",
[]byte(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message."}}`),
))
require.False(t, isOpenAITransientProcessingError(
http.StatusBadRequest,
"Missing required parameter: 'instructions'",
[]byte(`{"error":{"message":"Missing required parameter: 'instructions'"}}`),
))
}
func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t)
@@ -264,3 +284,51 @@ func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing
require.True(t, logSink.ContainsField("request_body_size"))
require.False(t, logSink.ContainsField("request_body_preview"))
}
func TestOpenAIGatewayService_Forward_TransientProcessingErrorTriggersFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"rid-processing-400"},
},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"An error occurred while processing your request. You can retry your request, or contact us through our help center at help.openai.com if the error persists. Please include the request ID req_123 in your message.","type":"invalid_request_error"}}`)),
},
}
svc := &OpenAIGatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{ForceCodexCLI: false},
},
httpUpstream: upstream,
}
account := &Account{
ID: 1001,
Name: "codex max套餐",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{"api_key": "sk-test"},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}]}`)
_, err := svc.Forward(context.Background(), c, account, body)
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层换号,而不是直接向客户端写响应")
}

View File

@@ -130,6 +130,7 @@ type OpenAITokenInfo struct {
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
OrganizationID string `json:"organization_id,omitempty"`
PlanType string `json:"plan_type,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
@@ -202,6 +203,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
tokenInfo.OrganizationID = userInfo.OrganizationID
tokenInfo.PlanType = userInfo.PlanType
}
return tokenInfo, nil
@@ -246,6 +248,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
tokenInfo.OrganizationID = userInfo.OrganizationID
tokenInfo.PlanType = userInfo.PlanType
}
return tokenInfo, nil
@@ -510,6 +513,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
if tokenInfo.OrganizationID != "" {
creds["organization_id"] = tokenInfo.OrganizationID
}
if tokenInfo.PlanType != "" {
creds["plan_type"] = tokenInfo.PlanType
}
if strings.TrimSpace(tokenInfo.ClientID) != "" {
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
}

View File

@@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN
}
}
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) {
repo := &openAICodexSnapshotAsyncRepo{
updateExtraCh: make(chan map[string]any, 2),
rateLimitCh: make(chan time.Time, 2),
}
svc := &OpenAIGatewayService{
accountRepo: repo,
codexSnapshotThrottle: newAccountWriteThrottle(time.Hour),
}
snapshot := &OpenAICodexUsageSnapshot{
PrimaryUsedPercent: ptrFloat64WS(94),
PrimaryResetAfterSeconds: ptrIntWS(3600),
PrimaryWindowMinutes: ptrIntWS(10080),
SecondaryUsedPercent: ptrFloat64WS(22),
SecondaryResetAfterSeconds: ptrIntWS(1200),
SecondaryWindowMinutes: ptrIntWS(300),
}
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot)
select {
case <-repo.updateExtraCh:
case <-time.After(2 * time.Second):
t.Fatal("等待第一次 codex 快照落库超时")
}
select {
case updates := <-repo.updateExtraCh:
t.Fatalf("unexpected second codex snapshot write: %v", updates)
case <-time.After(200 * time.Millisecond):
}
}
func ptrFloat64WS(v float64) *float64 { return &v }
func ptrIntWS(v int) *int { return &v }

View File

@@ -506,6 +506,48 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.HasError && acc.TempUnschedulableUntil == nil
})), true
case "group_rate_limit_ratio":
if groupID == nil || *groupID <= 0 {
return 0, false
}
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
if availability.Group == nil || availability.Group.TotalAccounts <= 0 {
return 0, true
}
return (float64(availability.Group.RateLimitCount) / float64(availability.Group.TotalAccounts)) * 100, true
case "account_error_ratio":
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
total := int64(len(availability.Accounts))
if total <= 0 {
return 0, true
}
errorCount := countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.HasError && acc.TempUnschedulableUntil == nil
})
return (float64(errorCount) / float64(total)) * 100, true
case "overload_account_count":
if s == nil || s.opsService == nil {
return 0, false
}
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
if err != nil || availability == nil {
return 0, false
}
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
return acc.IsOverloaded
})), true
}
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{

View File

@@ -7,6 +7,7 @@ import (
type OpsRepository interface {
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)

View File

@@ -7,6 +7,8 @@ import (
// opsRepoMock is a test-only OpsRepository implementation with optional function hooks.
type opsRepoMock struct {
InsertErrorLogFn func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
BatchInsertErrorLogsFn func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error)
BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error)
ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error)
DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
@@ -14,9 +16,19 @@ type opsRepoMock struct {
}
func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
if m.InsertErrorLogFn != nil {
return m.InsertErrorLogFn(ctx, input)
}
return 0, nil
}
func (m *opsRepoMock) BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
if m.BatchInsertErrorLogsFn != nil {
return m.BatchInsertErrorLogsFn(ctx, inputs)
}
return int64(len(inputs)), nil
}
func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil
}

View File

@@ -121,14 +121,74 @@ func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool {
}
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
if entry == nil {
prepared, ok, err := s.prepareErrorLogInput(ctx, entry, rawRequestBody)
if err != nil {
log.Printf("[Ops] RecordError prepare failed: %v", err)
return err
}
if !ok {
return nil
}
if _, err := s.opsRepo.InsertErrorLog(ctx, prepared); err != nil {
// Never bubble up to gateway; best-effort logging.
log.Printf("[Ops] RecordError failed: %v", err)
return err
}
return nil
}
func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertErrorLogInput) error {
if len(entries) == 0 {
return nil
}
prepared := make([]*OpsInsertErrorLogInput, 0, len(entries))
for _, entry := range entries {
item, ok, err := s.prepareErrorLogInput(ctx, entry, nil)
if err != nil {
log.Printf("[Ops] RecordErrorBatch prepare failed: %v", err)
continue
}
if ok {
prepared = append(prepared, item)
}
}
if len(prepared) == 0 {
return nil
}
if len(prepared) == 1 {
_, err := s.opsRepo.InsertErrorLog(ctx, prepared[0])
if err != nil {
log.Printf("[Ops] RecordErrorBatch single insert failed: %v", err)
}
return err
}
if _, err := s.opsRepo.BatchInsertErrorLogs(ctx, prepared); err != nil {
log.Printf("[Ops] RecordErrorBatch failed, fallback to single inserts: %v", err)
var firstErr error
for _, entry := range prepared {
if _, insertErr := s.opsRepo.InsertErrorLog(ctx, entry); insertErr != nil {
log.Printf("[Ops] RecordErrorBatch fallback insert failed: %v", insertErr)
if firstErr == nil {
firstErr = insertErr
}
}
}
return firstErr
}
return nil
}
func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) (*OpsInsertErrorLogInput, bool, error) {
if entry == nil {
return nil, false, nil
}
if !s.IsMonitoringEnabled(ctx) {
return nil
return nil, false, nil
}
if s.opsRepo == nil {
return nil
return nil, false, nil
}
// Ensure timestamps are always populated.
@@ -185,85 +245,88 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
}
}
// Sanitize + serialize upstream error events list.
if len(entry.UpstreamErrors) > 0 {
const maxEvents = 32
events := entry.UpstreamErrors
if len(events) > maxEvents {
events = events[len(events)-maxEvents:]
if err := sanitizeOpsUpstreamErrors(entry); err != nil {
return nil, false, err
}
return entry, true, nil
}
func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
if entry == nil || len(entry.UpstreamErrors) == 0 {
return nil
}
const maxEvents = 32
events := entry.UpstreamErrors
if len(events) > maxEvents {
events = events[len(events)-maxEvents:]
}
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
for _, ev := range events {
if ev == nil {
continue
}
out := *ev
out.Platform = strings.TrimSpace(out.Platform)
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
if out.AccountID < 0 {
out.AccountID = 0
}
if out.UpstreamStatusCode < 0 {
out.UpstreamStatusCode = 0
}
if out.AtUnixMs < 0 {
out.AtUnixMs = 0
}
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
for _, ev := range events {
if ev == nil {
continue
}
out := *ev
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
msg = truncateString(msg, 2048)
out.Message = msg
out.Platform = strings.TrimSpace(out.Platform)
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
detail := strings.TrimSpace(out.Detail)
if detail != "" {
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
out.Detail = sanitizedDetail
} else {
out.Detail = ""
}
if out.AccountID < 0 {
out.AccountID = 0
}
if out.UpstreamStatusCode < 0 {
out.UpstreamStatusCode = 0
}
if out.AtUnixMs < 0 {
out.AtUnixMs = 0
}
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
msg = truncateString(msg, 2048)
out.Message = msg
detail := strings.TrimSpace(out.Detail)
if detail != "" {
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
out.Detail = sanitizedDetail
} else {
out.Detail = ""
}
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
if out.UpstreamRequestBody != "" {
// Reuse the same sanitization/trimming strategy as request body storage.
// Keep it small so it is safe to persist in ops_error_logs JSON.
sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
if sanitized != "" {
out.UpstreamRequestBody = sanitized
if truncated {
out.Kind = strings.TrimSpace(out.Kind)
if out.Kind == "" {
out.Kind = "upstream"
}
out.Kind = out.Kind + ":request_body_truncated"
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
if out.UpstreamRequestBody != "" {
// Reuse the same sanitization/trimming strategy as request body storage.
// Keep it small so it is safe to persist in ops_error_logs JSON.
sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
if sanitizedBody != "" {
out.UpstreamRequestBody = sanitizedBody
if truncated {
out.Kind = strings.TrimSpace(out.Kind)
if out.Kind == "" {
out.Kind = "upstream"
}
} else {
out.UpstreamRequestBody = ""
out.Kind = out.Kind + ":request_body_truncated"
}
} else {
out.UpstreamRequestBody = ""
}
// Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue
}
evCopy := out
sanitized = append(sanitized, &evCopy)
}
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
entry.UpstreamErrors = nil
// Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue
}
evCopy := out
sanitized = append(sanitized, &evCopy)
}
if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil {
// Never bubble up to gateway; best-effort logging.
log.Printf("[Ops] RecordError failed: %v", err)
return err
}
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
entry.UpstreamErrors = nil
return nil
}

View File

@@ -0,0 +1,103 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestOpsServiceRecordErrorBatch_SanitizesAndBatches(t *testing.T) {
t.Parallel()
var captured []*OpsInsertErrorLogInput
repo := &opsRepoMock{
BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
captured = append(captured, inputs...)
return int64(len(inputs)), nil
},
}
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
msg := " upstream failed: https://example.com?access_token=secret-value "
detail := `{"authorization":"Bearer secret-token"}`
entries := []*OpsInsertErrorLogInput{
{
ErrorBody: `{"error":"bad","access_token":"secret"}`,
UpstreamStatusCode: intPtr(-10),
UpstreamErrorMessage: strPtr(msg),
UpstreamErrorDetail: strPtr(detail),
UpstreamErrors: []*OpsUpstreamErrorEvent{
{
AccountID: -2,
UpstreamStatusCode: 429,
Message: " token leaked ",
Detail: `{"refresh_token":"secret"}`,
UpstreamRequestBody: `{"api_key":"secret","messages":[{"role":"user","content":"hello"}]}`,
},
},
},
{
ErrorPhase: "upstream",
ErrorType: "upstream_error",
CreatedAt: time.Now().UTC(),
},
}
require.NoError(t, svc.RecordErrorBatch(context.Background(), entries))
require.Len(t, captured, 2)
first := captured[0]
require.Equal(t, "internal", first.ErrorPhase)
require.Equal(t, "api_error", first.ErrorType)
require.Nil(t, first.UpstreamStatusCode)
require.NotNil(t, first.UpstreamErrorMessage)
require.NotContains(t, *first.UpstreamErrorMessage, "secret-value")
require.Contains(t, *first.UpstreamErrorMessage, "access_token=***")
require.NotNil(t, first.UpstreamErrorDetail)
require.NotContains(t, *first.UpstreamErrorDetail, "secret-token")
require.NotContains(t, first.ErrorBody, "secret")
require.Nil(t, first.UpstreamErrors)
require.NotNil(t, first.UpstreamErrorsJSON)
require.NotContains(t, *first.UpstreamErrorsJSON, "secret")
require.Contains(t, *first.UpstreamErrorsJSON, "[REDACTED]")
second := captured[1]
require.Equal(t, "upstream", second.ErrorPhase)
require.Equal(t, "upstream_error", second.ErrorType)
require.False(t, second.CreatedAt.IsZero())
}
func TestOpsServiceRecordErrorBatch_FallsBackToSingleInsert(t *testing.T) {
t.Parallel()
var (
batchCalls int
singleCalls int
)
repo := &opsRepoMock{
BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) {
batchCalls++
return 0, errors.New("batch failed")
},
InsertErrorLogFn: func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) {
singleCalls++
return int64(singleCalls), nil
},
}
svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
err := svc.RecordErrorBatch(context.Background(), []*OpsInsertErrorLogInput{
{ErrorMessage: "first"},
{ErrorMessage: "second"},
})
require.NoError(t, err)
require.Equal(t, 1, batchCalls)
require.Equal(t, 2, singleCalls)
}
func strPtr(v string) *string {
return &v
}

View File

@@ -1247,6 +1247,60 @@ func (s *SettingService) IsBudgetRectifierEnabled(ctx context.Context) bool {
return settings.Enabled && settings.ThinkingBudgetEnabled
}
// GetBetaPolicySettings 获取 Beta 策略配置
func (s *SettingService) GetBetaPolicySettings(ctx context.Context) (*BetaPolicySettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyBetaPolicySettings)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return DefaultBetaPolicySettings(), nil
}
return nil, fmt.Errorf("get beta policy settings: %w", err)
}
if value == "" {
return DefaultBetaPolicySettings(), nil
}
var settings BetaPolicySettings
if err := json.Unmarshal([]byte(value), &settings); err != nil {
return DefaultBetaPolicySettings(), nil
}
return &settings, nil
}
// SetBetaPolicySettings 设置 Beta 策略配置
func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *BetaPolicySettings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
validActions := map[string]bool{
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
}
validScopes := map[string]bool{
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true,
}
for i, rule := range settings.Rules {
if rule.BetaToken == "" {
return fmt.Errorf("rule[%d]: beta_token cannot be empty", i)
}
if !validActions[rule.Action] {
return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
}
if !validScopes[rule.Scope] {
return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
}
}
data, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("marshal beta policy settings: %w", err)
}
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
}
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {

View File

@@ -191,3 +191,45 @@ func DefaultRectifierSettings() *RectifierSettings {
ThinkingBudgetEnabled: true,
}
}
// Beta Policy 策略常量
const (
BetaPolicyActionPass = "pass" // 透传,不做任何处理
BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token
BetaPolicyActionBlock = "block" // 拦截,直接返回错误
BetaPolicyScopeAll = "all" // 所有账号类型
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
)
// BetaPolicyRule 单条 Beta 策略规则
type BetaPolicyRule struct {
BetaToken string `json:"beta_token"` // beta token 值
Action string `json:"action"` // "pass" | "filter" | "block"
Scope string `json:"scope"` // "all" | "oauth" | "apikey"
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
}
// BetaPolicySettings Beta 策略配置
type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// DefaultBetaPolicySettings 返回默认的 Beta 策略配置
func DefaultBetaPolicySettings() *BetaPolicySettings {
return &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "fast-mode-2026-02-01",
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
},
{
BetaToken: "context-1m-2025-08-07",
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
},
},
}
}

View File

@@ -0,0 +1,166 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage
// 其余方法继承 userSubRepoNooppanic
type resetQuotaUserSubRepoStub struct {
userSubRepoNoop
sub *UserSubscription
resetDailyCalled bool
resetWeeklyCalled bool
resetDailyErr error
resetWeeklyErr error
}
func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
if r.sub == nil || r.sub.ID != id {
return nil, ErrSubscriptionNotFound
}
cp := *r.sub
return &cp, nil
}
func (r *resetQuotaUserSubRepoStub) ResetDailyUsage(_ context.Context, _ int64, windowStart time.Time) error {
r.resetDailyCalled = true
if r.resetDailyErr == nil && r.sub != nil {
r.sub.DailyUsageUSD = 0
r.sub.DailyWindowStart = &windowStart
}
return r.resetDailyErr
}
func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ int64, _ time.Time) error {
r.resetWeeklyCalled = true
return r.resetWeeklyErr
}
func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService {
return NewSubscriptionService(groupRepoNoop{}, stub, nil, nil, nil)
}
func TestAdminResetQuota_ResetBoth(t *testing.T) {
stub := &resetQuotaUserSubRepoStub{
sub: &UserSubscription{ID: 1, UserID: 10, GroupID: 20},
}
svc := newResetQuotaSvc(stub)
result, err := svc.AdminResetQuota(context.Background(), 1, true, true)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage")
require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
}
func TestAdminResetQuota_ResetDailyOnly(t *testing.T) {
stub := &resetQuotaUserSubRepoStub{
sub: &UserSubscription{ID: 2, UserID: 10, GroupID: 20},
}
svc := newResetQuotaSvc(stub)
result, err := svc.AdminResetQuota(context.Background(), 2, true, false)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage")
require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage")
}
func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) {
stub := &resetQuotaUserSubRepoStub{
sub: &UserSubscription{ID: 3, UserID: 10, GroupID: 20},
}
svc := newResetQuotaSvc(stub)
result, err := svc.AdminResetQuota(context.Background(), 3, false, true)
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage")
require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
}
func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) {
stub := &resetQuotaUserSubRepoStub{
sub: &UserSubscription{ID: 7, UserID: 10, GroupID: 20},
}
svc := newResetQuotaSvc(stub)
_, err := svc.AdminResetQuota(context.Background(), 7, false, false)
require.ErrorIs(t, err, ErrInvalidInput)
require.False(t, stub.resetDailyCalled)
require.False(t, stub.resetWeeklyCalled)
}
func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) {
stub := &resetQuotaUserSubRepoStub{sub: nil}
svc := newResetQuotaSvc(stub)
_, err := svc.AdminResetQuota(context.Background(), 999, true, true)
require.ErrorIs(t, err, ErrSubscriptionNotFound)
require.False(t, stub.resetDailyCalled)
require.False(t, stub.resetWeeklyCalled)
}
func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) {
dbErr := errors.New("db error")
stub := &resetQuotaUserSubRepoStub{
sub: &UserSubscription{ID: 4, UserID: 10, GroupID: 20},
resetDailyErr: dbErr,
}
svc := newResetQuotaSvc(stub)
_, err := svc.AdminResetQuota(context.Background(), 4, true, true)
require.ErrorIs(t, err, dbErr)
require.True(t, stub.resetDailyCalled)
require.False(t, stub.resetWeeklyCalled, "daily 失败后不应继续调用 weekly")
}
func TestAdminResetQuota_ResetWeeklyUsageError(t *testing.T) {
dbErr := errors.New("db error")
stub := &resetQuotaUserSubRepoStub{
sub: &UserSubscription{ID: 5, UserID: 10, GroupID: 20},
resetWeeklyErr: dbErr,
}
svc := newResetQuotaSvc(stub)
_, err := svc.AdminResetQuota(context.Background(), 5, false, true)
require.ErrorIs(t, err, dbErr)
require.True(t, stub.resetWeeklyCalled)
}
func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) {
stub := &resetQuotaUserSubRepoStub{
sub: &UserSubscription{
ID: 6,
UserID: 10,
GroupID: 20,
DailyUsageUSD: 99.9,
},
}
svc := newResetQuotaSvc(stub)
result, err := svc.AdminResetQuota(context.Background(), 6, true, false)
require.NoError(t, err)
// ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零,
// 服务应返回第二次 GetByID 的刷新值而非初始的 99.9
require.Equal(t, float64(0), result.DailyUsageUSD, "返回的订阅应反映已归零的用量")
require.True(t, stub.resetDailyCalled)
}

View File

@@ -31,6 +31,7 @@ var (
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
@@ -695,6 +696,36 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
}
// AdminResetQuota manually resets the daily and/or weekly usage windows.
// Uses startOfDay(now) as the new window start, matching automatic resets.
func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) {
if !resetDaily && !resetWeekly {
return nil, ErrInvalidInput
}
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, err
}
windowStart := startOfDay(time.Now())
if resetDaily {
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
return nil, err
}
}
if resetWeekly {
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
return nil, err
}
}
// Invalidate caches, same as CheckAndResetWindows
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
}
// Return the refreshed subscription from DB
return s.userSubRepo.GetByID(ctx, subscriptionID)
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
// 使用当天零点作为新窗口起始时间

View File

@@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache)
if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
}
if cfg != nil {
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
}

View File

@@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco
func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
return nil
}
func (c StubConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error {
return nil
}
// ============================================================
// StubGatewayCache — service.GatewayCache 的空实现

View File

@@ -0,0 +1,51 @@
-- Add gemini-2.5-flash-image aliases to Antigravity model_mapping
--
-- Background:
-- Gemini native image generation now relies on gemini-2.5-flash-image, and
-- existing Antigravity accounts with persisted model_mapping need this alias in
-- order to participate in mixed scheduling from gemini groups.
--
-- Strategy:
-- Overwrite the stored model_mapping so it matches DefaultAntigravityModelMapping
-- in constants.go, including legacy gemini-3-pro-image aliases.
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{model_mapping}',
'{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-6": "claude-opus-4-6-thinking",
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
"gemini-3-pro-image": "gemini-3.1-flash-image",
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}'::jsonb
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping' IS NOT NULL;

View File

@@ -8,7 +8,7 @@
# - Creates necessary data directories
#
# After running this script, you can start services with:
# docker-compose -f docker-compose.local.yml up -d
# docker-compose up -d
# =============================================================================
set -e
@@ -65,7 +65,7 @@ main() {
fi
# Check if deployment already exists
if [ -f "docker-compose.local.yml" ] && [ -f ".env" ]; then
if [ -f "docker-compose.yml" ] && [ -f ".env" ]; then
print_warning "Deployment files already exist in current directory."
read -p "Overwrite existing files? (y/N): " -r
echo
@@ -75,17 +75,17 @@ main() {
fi
fi
# Download docker-compose.local.yml
print_info "Downloading docker-compose.local.yml..."
# Download docker-compose.local.yml and save as docker-compose.yml
print_info "Downloading docker-compose.yml..."
if command_exists curl; then
curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.local.yml
curl -sSL "${GITHUB_RAW_URL}/docker-compose.local.yml" -o docker-compose.yml
elif command_exists wget; then
wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.local.yml
wget -q "${GITHUB_RAW_URL}/docker-compose.local.yml" -O docker-compose.yml
else
print_error "Neither curl nor wget is installed. Please install one of them."
exit 1
fi
print_success "Downloaded docker-compose.local.yml"
print_success "Downloaded docker-compose.yml"
# Download .env.example
print_info "Downloading .env.example..."
@@ -144,7 +144,7 @@ main() {
print_warning "Please keep them secure and do not share publicly!"
echo ""
echo "Directory structure:"
echo " docker-compose.local.yml - Docker Compose configuration"
echo " docker-compose.yml - Docker Compose configuration"
echo " .env - Environment variables (generated secrets)"
echo " .env.example - Example template (for reference)"
echo " data/ - Application data (will be created on first run)"
@@ -154,10 +154,10 @@ main() {
echo "Next steps:"
echo " 1. (Optional) Edit .env to customize configuration"
echo " 2. Start services:"
echo " docker-compose -f docker-compose.local.yml up -d"
echo " docker-compose up -d"
echo ""
echo " 3. View logs:"
echo " docker-compose -f docker-compose.local.yml logs -f sub2api"
echo " docker-compose logs -f sub2api"
echo ""
echo " 4. Access Web UI:"
echo " http://localhost:8080"

View File

@@ -99,16 +99,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
}'
```
### 4) 购买页 URL Query 透传iframe / 新窗口一致)
当 Sub2API 打开 `purchase_subscription_url` 时,会统一追加:
### 4) 购买页 / 自定义页面 URL Query 透传iframe / 新窗口一致)
当 Sub2API 打开 `purchase_subscription_url` 或用户侧自定义页面 iframe URL 时,会统一追加:
- `user_id`
- `token`
- `theme``light` / `dark`
- `lang`(例如 `zh` / `en`,用于向嵌入页传递当前界面语言)
- `ui_mode`(固定 `embedded`
示例:
```text
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&ui_mode=embedded
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&lang=zh&ui_mode=embedded
```
### 5) 失败处理建议
@@ -218,16 +219,17 @@ curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
}'
```
### 4) Purchase URL query forwarding (iframe and new tab)
When Sub2API opens `purchase_subscription_url`, it appends:
### 4) Purchase / Custom Page URL query forwarding (iframe and new tab)
When Sub2API opens `purchase_subscription_url` or a user-facing custom page iframe URL, it appends:
- `user_id`
- `token`
- `theme` (`light` / `dark`)
- `lang` (for example `zh` / `en`, used to pass the current UI language to the embedded page)
- `ui_mode` (fixed: `embedded`)
Example:
```text
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&ui_mode=embedded
https://pay.example.com/pay?user_id=123&token=<jwt>&theme=light&lang=zh&ui_mode=embedded
```
### 5) Failure handling recommendations

View File

@@ -581,6 +581,43 @@ export async function validateSoraSessionToken(
return data
}
/**
* Batch operation result type
*/
export interface BatchOperationResult {
total: number
success: number
failed: number
errors?: Array<{ account_id: number; error: string }>
warnings?: Array<{ account_id: number; warning: string }>
}
/**
* Batch clear account errors
* @param accountIds - Array of account IDs
* @returns Batch operation result
*/
export async function batchClearError(accountIds: number[]): Promise<BatchOperationResult> {
const { data } = await apiClient.post<BatchOperationResult>('/admin/accounts/batch-clear-error', {
account_ids: accountIds
})
return data
}
/**
* Batch refresh account credentials
* @param accountIds - Array of account IDs
* @returns Batch operation result
*/
export async function batchRefresh(accountIds: number[]): Promise<BatchOperationResult> {
const { data } = await apiClient.post<BatchOperationResult>('/admin/accounts/batch-refresh', {
account_ids: accountIds,
}, {
timeout: 120000 // 120s timeout for large batch refreshes
})
return data
}
export const accountsAPI = {
list,
listWithEtag,
@@ -615,7 +652,9 @@ export const accountsAPI = {
syncFromCrs,
exportData,
importData,
getAntigravityDefaultModelMapping
getAntigravityDefaultModelMapping,
batchClearError,
batchRefresh
}
export default accountsAPI

View File

@@ -308,6 +308,49 @@ export async function updateRectifierSettings(
return data
}
// ==================== Beta Policy Settings ====================
/**
* Beta policy rule interface
*/
export interface BetaPolicyRule {
beta_token: string
action: 'pass' | 'filter' | 'block'
scope: 'all' | 'oauth' | 'apikey'
error_message?: string
}
/**
* Beta policy settings interface
*/
export interface BetaPolicySettings {
rules: BetaPolicyRule[]
}
/**
* Get beta policy settings
* @returns Beta policy settings
*/
export async function getBetaPolicySettings(): Promise<BetaPolicySettings> {
const { data } = await apiClient.get<BetaPolicySettings>('/admin/settings/beta-policy')
return data
}
/**
* Update beta policy settings
* @param settings - Beta policy settings to update
* @returns Updated settings
*/
export async function updateBetaPolicySettings(
settings: BetaPolicySettings
): Promise<BetaPolicySettings> {
const { data } = await apiClient.put<BetaPolicySettings>(
'/admin/settings/beta-policy',
settings
)
return data
}
// ==================== Sora S3 Settings ====================
export interface SoraS3Settings {
@@ -456,6 +499,8 @@ export const settingsAPI = {
updateStreamTimeoutSettings,
getRectifierSettings,
updateRectifierSettings,
getBetaPolicySettings,
updateBetaPolicySettings,
getSoraS3Settings,
updateSoraS3Settings,
testSoraS3Connection,

View File

@@ -120,6 +120,23 @@ export async function revoke(id: number): Promise<{ message: string }> {
return data
}
/**
* Reset daily and/or weekly usage quota for a subscription
* @param id - Subscription ID
* @param options - Which windows to reset
* @returns Updated subscription
*/
export async function resetQuota(
id: number,
options: { daily: boolean; weekly: boolean }
): Promise<UserSubscription> {
const { data } = await apiClient.post<UserSubscription>(
`/admin/subscriptions/${id}/reset-quota`,
options
)
return data
}
/**
* List subscriptions by group
* @param groupId - Group ID
@@ -170,6 +187,7 @@ export const subscriptionsAPI = {
bulkAssign,
extend,
revoke,
resetQuota,
listByGroup,
listByUser
}

View File

@@ -335,6 +335,28 @@ export async function resetPassword(request: ResetPasswordRequest): Promise<Rese
return data
}
/**
* Complete LinuxDo OAuth registration by supplying an invitation code
* @param pendingOAuthToken - Short-lived JWT from the OAuth callback
* @param invitationCode - Invitation code entered by the user
* @returns Token pair on success
*/
export async function completeLinuxDoOAuthRegistration(
pendingOAuthToken: string,
invitationCode: string
): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
const { data } = await apiClient.post<{
access_token: string
refresh_token: string
expires_in: number
token_type: string
}>('/auth/oauth/linuxdo/complete-registration', {
pending_oauth_token: pendingOAuthToken,
invitation_code: invitationCode
})
return data
}
export const authAPI = {
login,
login2FA,
@@ -357,7 +379,8 @@ export const authAPI = {
forgotPassword,
resetPassword,
refreshToken,
revokeAllSessions
revokeAllSessions,
completeLinuxDoOAuthRegistration
}
export default authAPI

View File

@@ -73,11 +73,9 @@
</div>
<!-- API Key 账号配额限制 -->
<div v-if="showDailyQuota || showWeeklyQuota || showTotalQuota" class="flex items-center gap-1">
<QuotaBadge v-if="showDailyQuota" :used="account.quota_daily_used ?? 0" :limit="account.quota_daily_limit!" label="D" />
<QuotaBadge v-if="showWeeklyQuota" :used="account.quota_weekly_used ?? 0" :limit="account.quota_weekly_limit!" label="W" />
<QuotaBadge v-if="showTotalQuota" :used="account.quota_used ?? 0" :limit="account.quota_limit!" />
</div>
<QuotaBadge v-if="showDailyQuota" :used="account.quota_daily_used ?? 0" :limit="account.quota_daily_limit!" label="D" />
<QuotaBadge v-if="showWeeklyQuota" :used="account.quota_weekly_used ?? 0" :limit="account.quota_weekly_limit!" label="W" />
<QuotaBadge v-if="showTotalQuota" :used="account.quota_used ?? 0" :limit="account.quota_limit!" />
</div>
</template>

Some files were not shown because too many files have changed in this diff Show More