Compare commits

..

107 Commits

Author SHA1 Message Date
shaw
25178cdbe1 fix: 修复gpt->claude同步请求返回sse的bug 2026-03-09 15:58:44 +08:00
shaw
a461538d58 fix: 修复gpt->claude转换无法命中codex缓存问题 2026-03-09 15:08:37 +08:00
shaw
ebe6f418f3 fix: gpt->claude格式转换对齐effort映射和fast 2026-03-09 11:42:35 +08:00
Wesley Liddick
391e79f8ee Merge pull request #875 from mt21625457/fix/openai-fast-billing-clean
fix(billing): 修复 OpenAI fast 档位计费并补齐展示
2026-03-09 10:32:18 +08:00
shaw
c7fcb7a84b feat: apikey限额支持查询重置时间 2026-03-09 10:22:24 +08:00
yangjianbo
87f4ed591e fix(billing): 修复 OpenAI fast 档位计费并补齐展示
- 打通 service_tier 在 OpenAI HTTP、WS、passthrough 与 usage 记录中的传递
- 修正 priority/flex 计费逻辑,并将 fast 归一化为 priority
- 在用户端和管理端补齐服务档位与计费明细展示
- 补齐前后端测试,并修复 WS 限流信号重复持久化导致的全量回归失败

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 09:51:26 +08:00
shaw
440d2e28ed fix: 恢复context-1m-2025-08-07在oauth账号中传递 2026-03-09 09:29:33 +08:00
Wesley Liddick
6cb8980404 Merge pull request #807 from touwaeriol/fix/openai-passthrough-v2
fix(openai): remove misplaced passthrough check from isModelSupportedByAccount
2026-03-09 09:06:35 +08:00
Wesley Liddick
fe752bbd35 Merge pull request #853 from touwaeriol/pr/swipe-select-admin-tables
feat(frontend): 为后台账号管理和 IP 管理增加拖筐选中能力
2026-03-09 09:04:02 +08:00
Wesley Liddick
c74d451fa2 Merge pull request #874 from touwaeriol/pr/increase-sse-max-line-size
fix: increase SSE scanner max line size from 40MB to 500MB
2026-03-09 09:03:39 +08:00
Wesley Liddick
12d743fb35 Merge pull request #868 from touwaeriol/pr/bump-antigravity-version
chore: bump Antigravity user agent version to 1.20.4
2026-03-09 09:03:25 +08:00
Wesley Liddick
6acb9f7910 Merge pull request #864 from StarryKira/fix/clear-thinking-context-management
[Fix] Fix issue #851
2026-03-09 09:02:58 +08:00
erio
eb6f5c6927 test: update UserAgent version assertion to match 1.20.4 2026-03-09 08:59:21 +08:00
erio
7ccb4c8ea3 chore: bump Antigravity user agent version to 1.20.4 2026-03-09 08:59:21 +08:00
erio
4ce986d47d fix: also update viper default max_line_size from 40MB to 500MB
The viper config default (config.go) was overriding the constant
in gateway_service.go. Both must be updated to take effect.
2026-03-09 08:56:54 +08:00
erio
91ef085d7d fix: increase SSE scanner max line size from 40MB to 500MB
4K image base64 data can exceed 40MB limit, causing "bufio.Scanner:
token too long" errors. Scanner is adaptive (starts at 64KB, grows
as needed), so increasing the cap has no impact on normal responses.
2026-03-09 08:56:54 +08:00
Wesley Liddick
97aaa24733 Merge pull request #858 from james-6-23/fix/pool-mode-03bf3485
支持 API Key 上游池模式的同账号重试次数配置与自定义错误策略
2026-03-09 08:48:53 +08:00
Wesley Liddick
faf6441633 Merge pull request #854 from james-6-23/main
feat(admin): 支持定时测试自动恢复并统一账号恢复入口
2026-03-09 08:48:36 +08:00
shaw
00c151b463 feat: gpt->claude格式转换支持图片识别 2026-03-09 08:44:09 +08:00
Wesley Liddick
a2ae9f1f27 Merge pull request #866 from touwaeriol/pr/apikey-quota-usage-bars
feat(frontend): add API Key quota progress bars to usage window
2026-03-08 21:50:24 +08:00
erio
4cd6d86426 feat(frontend): add API Key quota progress bars to usage window column
Display daily/weekly/total quota utilization as progress bars in the
usage window column for API Key accounts, providing visual feedback
consistent with other account types (OAuth/Gemini).

- Daily quota: "1d" label with reset countdown
- Weekly quota: "7d" label with reset countdown
- Total quota: "total" label (no reset)
2026-03-08 21:35:26 +08:00
时雨遥
fa72f1947a Update backend/internal/service/gateway_request_test.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-08 21:21:36 +08:00
时雨遥
9ee7d3935d Update backend/internal/service/gateway_request.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-08 21:21:28 +08:00
Elysia
1071fe0ac7 add test file 2026-03-08 21:08:09 +08:00
erio
0be003377f feat(frontend): add swipe-to-select for admin tables
Squash of all swipe-select commits for clean rebase.
2026-03-08 21:07:43 +08:00
Elysia
ca3f497b56 fix issue #851 2026-03-08 21:00:34 +08:00
Wesley Liddick
034b84b707 Merge pull request #861 from bayma888/feature/usage-user-balance-popup
feat(ui): 使用记录页面点击用户邮箱可查看用户信息和充值记录
2026-03-08 20:37:45 +08:00
Wesley Liddick
1624523c4e Merge pull request #860 from bayma888/feature/group-display-fix
feat(ui): 优化分组选择器、交互体验和样式遮挡体验问题
2026-03-08 20:37:36 +08:00
Wesley Liddick
313afe14ce Merge pull request #842 from pkssssss/fix/openai-ws-usage-refresh
fix: 修复 OpenAI WS 用量窗口刷新与限额状态不同步
2026-03-08 20:34:54 +08:00
Wesley Liddick
01180b316f Merge pull request #841 from touwaeriol/feature/account-periodic-quota
feat(account): 为 API Key 账号新增日/周周期性配额限制
2026-03-08 20:34:15 +08:00
Wesley Liddick
ee7d061001 Merge pull request #839 from pkssssss/fix/simple-mode-admin-concurrency-30
fix: 简易模式仅提升管理员默认并发到 30
2026-03-08 20:33:11 +08:00
bayma888
60c5949a74 feat(ui): 使用记录页面点击用户邮箱可查看充值记录
- UsageTable 用户邮箱改为可点击链接,点击弹出余额变动记录
- 复用 UserBalanceHistoryModal 组件,通过 getById API 获取完整用户信息
- 新增 hideActions prop 隐藏充值/退款按钮(Usage 页面仅查看)
- i18n: 新增 clickToViewBalance、failedToLoadUser 词条 (en/zh)
2026-03-08 19:11:28 +08:00
bayma888
2ebbd4c94d feat(ui): 优化分组选择器交互体验
- 分组下拉添加搜索框,支持按名称/描述快速筛选
- 新建/编辑密钥弹窗的分组选择也支持搜索
- 智能弹出方向:底部空间不足时自动向上弹出
- 倍率独立为平台配色的圆角标签,更醒目
- 分组名称加粗,名称与描述之间增加间距
- 分组选项之间添加分隔线,视觉更清晰
- 切换图标旁增加"选择分组"文字提示
- 下拉宽度自适应内容长度
- i18n: 新增 searchGroup、noGroupFound 词条 (en/zh)
2026-03-08 18:26:17 +08:00
bayma888
785115c62b fix(ui): improve group selector dropdown width and visibility
- Increase Select dropdown max-width from 320px to 480px for better content display
- Change KeysView group selector from fixed 256px to adaptive 280-480px width
- Make group switch icon always visible (60% opacity, 100% on hover)
- Allow group description to wrap to 2 lines instead of truncating
- Improve user experience for group selection in API keys page
2026-03-08 15:12:15 +08:00
kyx236
e643fc382c feat: 支持 API Key 上游池模式同账号重试次数配置与自定义错误策略 2026-03-08 14:12:17 +08:00
kyx236
34aad82ac3 fix(frontend): 修复后台页面 lint 校验问题 2026-03-08 07:34:15 +08:00
kyx236
0c29468f90 feat(admin): 支持定时测试自动恢复并统一账号恢复入口
- 为定时测试计划增加 auto_recover 配置,补齐前后端类型、接口、仓储与数据库迁移
- 在定时测试成功后自动恢复账号 error、rate-limit 等可恢复运行时状态
- 新增 /admin/accounts/:id/recover-state 接口,合并原有重置状态与清限流操作
- 更新账号管理菜单与定时测试面板,补充自动恢复开关、说明提示和状态展示
- 补充账号恢复、限流清理与仓储同步相关测试
2026-03-08 06:59:53 +08:00
神乐
9301dae63e fix: 修复 OpenAI WS 用量刷新遗漏场景 2026-03-08 04:37:20 +08:00
erio
2475d4a205 feat: add marquee selection box overlay during drag-to-select
Show a semi-transparent blue rectangle overlay while dragging to
select rows, matching the project's primary color theme with dark
mode support. The box spans the full table width from drag start
to current mouse position.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 01:45:22 +08:00
erio
be75fc3474 ci: retrigger CI checks 2026-03-08 01:42:18 +08:00
神乐
785e049af3 Merge branch 'Wei-Shaw:main' into fix/simple-mode-admin-concurrency-30 2026-03-08 00:16:59 +08:00
神乐
be4e49e6d7 Merge branch 'Wei-Shaw:main' into fix/openai-ws-usage-refresh 2026-03-08 00:16:51 +08:00
神乐
1307d604e7 fix: 补齐旧账号的 OpenAI 限流补偿 2026-03-08 00:14:15 +08:00
神乐
45d57018eb fix: 修复 OpenAI WS 限流状态与调度同步 2026-03-07 23:59:39 +08:00
shaw
03bf348530 fix(lint): gofmt formatting fixes for 3 files
Align struct field assignments and fix indentation detected by
golangci-lint v2.9's gofmt checker.
2026-03-07 23:24:09 +08:00
shaw
cab60ef735 fix(ci): downgrade golangci-lint v2.11 to v2.9 to fix lint timeout
v2.10.1 introduced a recursive markDepsForAnalyzingSource change that
causes all transitive dependencies (including 132K lines of ent
generated code) to be analyzed from source instead of compiled export
data, leading to >30 min CI timeout.

v2.9.0 is the first version with Go 1.26 support (PR #6271, merged
Feb 10 2026) and does not have this performance regression.
2026-03-07 23:10:03 +08:00
shaw
a3791104f9 feat: 支持后台设置是否启用整流开关 2026-03-07 21:55:38 +08:00
神乐
2b3e40bb2a test: 修复 PR842 的 CI 失败 2026-03-07 21:24:06 +08:00
神乐
0c1dcad429 test: 修复 PR842 的 CI 失败 2026-03-07 21:09:34 +08:00
神乐
101ef0cf62 fix: 限流账号自动退出调度并优化提示文案 2026-03-07 21:05:37 +08:00
神乐
0debe0a80c fix: 修复 OpenAI WS 用量窗口刷新与限额纠偏 2026-03-07 20:02:58 +08:00
erio
d22e62ac8a fix(test): add allow_messages_dispatch to group API contract test
The recent upstream commit added allow_messages_dispatch to the Group
DTO but did not update the API contract test expectation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 19:28:22 +08:00
erio
1ee17383f8 feat(account): add daily/weekly periodic quota limits for API Key accounts
Extend the existing total quota limit with daily and weekly periodic
dimensions. Each dimension is independently configurable and uses lazy
reset — when the period expires, usage is automatically reset to zero on
the next increment. Any dimension exceeding its limit will pause the
account from scheduling.

Backend:
- Add GetQuotaDailyLimit/Used, GetQuotaWeeklyLimit/Used, HasAnyQuotaLimit
- Rewrite IncrementQuotaUsed with atomic CTE SQL for 3-dimension update
- Rewrite ResetQuotaUsed to clear all dimensions and period timestamps
- Update postUsageBilling to use HasAnyQuotaLimit()
- Preserve daily/weekly used values on account edit

Frontend:
- Refactor QuotaLimitCard from single v-model to 3-dimension props
- Add QuotaBadge component for compact D/W/$ display
- Update AccountCapacityCell with per-dimension badges
- Update Create/Edit modals with daily/weekly quota fields
- Update AccountActionMenu hasQuotaLimit to check all dimensions
- Add i18n strings for daily/weekly/total quota labels

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 19:06:59 +08:00
神乐
b59c79c458 fix: 简易模式仅提升管理员默认并发到30 2026-03-07 18:19:04 +08:00
shaw
bcb6444f89 fix(ci): update Go version check in release workflow to 1.26.1 2026-03-07 17:11:50 +08:00
Wesley Liddick
c2b14693b4 Merge pull request #835 from biubiutata/codex/fix-openai-originator-detection
fix(openai): 统一官方 Codex 客户端识别逻辑
2026-03-07 17:03:52 +08:00
shaw
92d35409de feat: 为openai分组增加messages调度开关和默认映射模型 2026-03-07 17:02:19 +08:00
shaw
351a08f813 fix: announcement强制弹窗通知补全迁移sql 2026-03-07 15:36:18 +08:00
shaw
a58dc787a9 fix(ci): 精简golangci-lint配置解决v2.11超时问题
- 移除staticcheck 470+冗余检查项,all已包含全部
- unused: generated-is-used改为true,跳过ent 13万行生成代码分析
- unused: exported-fields-are-used改为true,避免全项目导出字段引用追踪
- unused: field-writes-are-uses改为true
2026-03-07 15:17:16 +08:00
shaw
7079edc2d0 feat: announcement支持强制弹窗通知 2026-03-07 15:06:13 +08:00
admin
da89583ccc fix(openai): detect official codex client by headers 2026-03-07 14:12:38 +08:00
shaw
a42a1f08e9 fix: 编辑error状态账号时保存报Status验证失败
后端UpdateAccountRequest.Status的oneof验证缺少error状态,
前端编辑表单也未处理error状态,导致编辑异常账号时无法保存
2026-03-07 13:47:08 +08:00
shaw
ebd5253e22 fix: /response端点移除强制注入大量instructions内容 2026-03-07 13:39:47 +08:00
shaw
6411645ffc fix: 适配claude code调度openai账号的websearch功能 2026-03-07 11:33:08 +08:00
shaw
c0c322ba16 chore: openai账号模型映射新增claude->gpt快捷映射按钮 2026-03-07 10:26:30 +08:00
shaw
d35c5cd491 feat: openai平台的apikey新增claude code使用教程 2026-03-07 10:14:57 +08:00
shaw
7a353028e7 fix: 修复keys速率限制未自动重置额度的bug 2026-03-07 10:13:51 +08:00
shaw
2d8d3b7857 fix(ci): upgrade golangci-lint v2.7 to v2.11 for Go 1.26 compatibility
golangci-lint v2.7 was built with Go 1.25 and cannot lint Go 1.26
targets. v2.8+ added Go 1.26 support.
2026-03-07 08:54:05 +08:00
Wesley Liddick
4190293b07 Merge pull request #823 from StarryKira/fix/empty-stream-failover
Fix/empty streamfix issue #791
2026-03-07 08:51:07 +08:00
Wesley Liddick
421b4c0aff Merge pull request #830 from ckken/pr/ccswitch-import-improvements
fix(ccswitch): improve import provider name and usage parsing
2026-03-07 08:49:09 +08:00
Wesley Liddick
cd69a7cb85 Merge pull request #820 from geminiwen/fix/apikey-credentials-preserve-existing-fields
fix(account): preserve existing credentials when saving apikey accounts
2026-03-07 08:46:51 +08:00
shaw
0c9ba9e86c fix(security): upgrade Go 1.25.7 to 1.26.1 to resolve 4 stdlib vulnerabilities
GO-2026-4602 (os), GO-2026-4601 (net/url), GO-2026-4600 and
GO-2026-4599 (crypto/x509). The crypto/x509 fixes are only
available in go1.26.1+, not backported to go1.25.x.
2026-03-07 08:45:55 +08:00
shaw
1b4d2a41c9 fix(openai): /v1/messages端点补齐Codex用量快照提取与错误透传规则
对齐/v1/responses的Forward方法,修复两处不一致:
- 成功响应时从响应头提取OAuth账号的Codex使用量数据
- 非failover错误场景下应用管理员配置的错误透传规则
2026-03-07 08:40:07 +08:00
Wesley Liddick
0787d2b47a Merge pull request #829 from JIA-ss/fix/usage-query-rate-limit
fix(usage): 修复用量查询 429 重试风暴,增加负缓存、请求去重与随机延迟
2026-03-07 08:32:29 +08:00
ckken
97bf1d85ab feat(ccswitch): use site_name as default provider name in import link 2026-03-07 01:19:10 +08:00
ckken
207a493fab fix(ccswitch): parse remaining quota from /v1/usage response 2026-03-07 01:19:10 +08:00
JIA-ss
1f3f9e131e fix: resolve golangci-lint errors (gofmt alignment, errcheck)
- Fix gofmt: align struct field comments in UsageCache, trim trailing
  whitespace on const comments
- Fix errcheck: use comma-ok on type assertion for singleflight result

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 00:58:08 +08:00
JIA-ss
4ddedfaaf9 merge: resolve conflict with main (keep both openAI probe and usage fix)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 00:49:30 +08:00
JIA-ss
3ebebef95f fix(usage): add negative caching, singleflight, and jitter to usage queries
Prevents 429 rate-limit retry storms and reduces upstream correlation risk
for Anthropic usage API queries.

Three changes:
1. Negative caching (1 min TTL) — 429/error responses are now cached,
   preventing every subsequent page load from re-triggering failed API calls.
2. singleflight dedup — concurrent requests for the same account are
   collapsed into a single upstream call, preventing cache stampede.
3. Random jitter (0–800 ms) — staggers multi-account cache-miss bursts so
   requests from different accounts don't hit upstream simultaneously with
   identical TLS fingerprints, reducing anti-abuse correlation risk.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 00:39:16 +08:00
Gemini Wen
9f7ad47598 fix(account): clean up stale credentials fields after spreading currentCredentials
When customErrorCodes is disabled or modelMapping is empty, explicitly
delete the fields inherited from currentCredentials spread to avoid
preserving stale values.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 23:38:58 +08:00
Gemini Wen
3c83cd8be2 Merge remote-tracking branch 'origin/main' into fix/apikey-credentials-preserve-existing-fields 2026-03-06 23:38:18 +08:00
Wesley Liddick
963b3b768c Merge pull request #825 from FizzlyCode/fix/setup-token-real-utilization
fix: Setup Token 账号使用真实 utilization 值替代状态估算
2026-03-06 22:46:20 +08:00
Wesley Liddick
f6709fb5d6 Merge pull request #824 from pkssssss/fix/ws-usage-window-pr
fix(openai): 修复 WS 模式下用量窗口不显示
2026-03-06 22:45:36 +08:00
shaw
921599948b feat: /v1/messages端点适配codex账号池 2026-03-06 22:44:07 +08:00
神乐
5df3cafa99 style(go): format account usage service 2026-03-06 21:31:36 +08:00
神乐
1a2143c1fe fix(openai): adapt messages path to codex transform signature 2026-03-06 21:17:27 +08:00
神乐
dd25281305 chore(test): resolve merge conflict for ws usage window pr 2026-03-06 21:16:21 +08:00
FizzlyCode
49d0301dde fix: Setup Token 账号使用真实 utilization 值替代状态估算
从响应头 anthropic-ratelimit-unified-5h-utilization 获取并存储真实
utilization 值,解决进度条始终显示 0% 的问题。窗口重置时清除旧值,
避免残留上个窗口的数据。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 21:04:44 +08:00
神乐
d90e56eb45 chore(openai): clean up ws usage window branch 2026-03-06 21:04:24 +08:00
神乐
838ada8864 fix(openai): restore ws usage window display 2026-03-06 20:49:47 +08:00
Elysia
65a106792a fix issue #791 2026-03-06 20:37:09 +08:00
Elysia
ee4bfcbb81 Merge remote-tracking branch 'origin/main' 2026-03-06 20:37:09 +08:00
Gemini Wen
a087f089b8 fix(account): preserve existing credentials when saving apikey accounts
When editing an apikey account, the credentials object was built from
scratch, causing fields like tier_id that are not exposed in the UI to
be silently dropped on save. Spread currentCredentials first so unknown
fields are retained, then let the known fields overwrite them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 20:19:19 +08:00
Wesley Liddick
afbe8bf001 Merge pull request #809 from alfadb/feature/openai-messages
feat(openai): 添加 /v1/messages 端点和 API 兼容层
2026-03-06 20:16:06 +08:00
Wesley Liddick
2a3ef0be06 Merge pull request #818 from pkssssss/fix/remote-compact
fix(openai): support remote compact task
2026-03-06 19:41:04 +08:00
神乐
3403909354 fix(openai): support remote compact task 2026-03-06 18:51:05 +08:00
Wesley Liddick
005d0c5f53 Merge pull request #815 from mt21625457/pr/openai-user-group-rate-upstream
fix(openai): 统一专属倍率计费链路并补齐回归测试
2026-03-06 17:33:09 +08:00
Wesley Liddick
8aaaeb29cc Merge pull request #813 from FizzlyCode/fix/account-usage-display
fix: 修复账号列表五小时用量显示为 $0.00 的问题
2026-03-06 17:25:03 +08:00
yangjianbo
230f8abd04 test(openai): 修复回归测试未使用字段告警
移除订阅扣费 stub 中未被使用的状态字段与赋值,
消除 golangci-lint 的 unused 告警,保持回归测试语义不变。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 17:08:41 +08:00
yangjianbo
a18bbb5f2f fix(openai): 统一专属倍率计费链路并补齐回归测试
抽取共享的用户分组专属倍率解析器,统一缓存、singleflight 与回退逻辑。\n\n让 OpenAI 独立计费链路复用专属倍率解析,修复 usage 记录与实际扣费未命中用户专属倍率的问题。\n\n补齐 OpenAI 计费与解析器单元测试,并修复全量回归中暴露的 lint 阻塞项。\n\nCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 16:47:51 +08:00
wioos
60fce4f1dc fix: 修复 lite 模式跳过窗口费用查询导致 $0.00 显示的问题
commit 80ae592c 引入 lite 模式优化首次加载性能,但将窗口费用查询也一起跳过了。
commit 491a7444 尝试用 30 秒快照缓存修复,但缓存过期后问题复现。

移除窗口费用查询的 lite/非 lite 区分,始终执行 PostgreSQL 聚合查询。
同时删除不再需要的 account_window_cost_cache.go 文件。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 16:42:12 +08:00
wioos
9af65efcdb fix: 修复 zh.ts 缺少 protocols 翻译
在 admin.proxies 下添加 protocols 对象的中文翻译,
与 en.ts 保持一致。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 16:42:12 +08:00
alfadb
bc194a7d8c fix: address PR review - Anthropic error format in panic recovery and nil guard
- Add recoverAnthropicMessagesPanic for Messages handler to return
  Anthropic-formatted errors instead of OpenAI Responses format on panic
- Add nil check for rateLimitService.HandleUpstreamError in
  ForwardAsAnthropic to match defensive pattern used elsewhere

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 15:40:15 +08:00
erio
c28f691f32 fix(openai): remove misplaced passthrough check from isModelSupportedByAccount
isModelSupportedByAccount 不被 OpenAI 调度路径调用,
OpenAI /responses 和 /chat/completions 走的是
openai_account_scheduler.go,透传短路已在 PR #806 的
第二个 commit 中正确添加到该文件。

此处的检查是多余的死代码,因为 OpenAI 账号不会走到
isModelSupportedByAccount 的这个分支。
2026-03-06 14:32:08 +08:00
alfadb
ff1f114989 feat(openai): add /v1/messages endpoint and API compatibility layer
Add Anthropic Messages API support for OpenAI platform groups, enabling
clients using Claude-style /v1/messages format to access OpenAI accounts
through automatic protocol conversion.

- Add apicompat package with type definitions and bidirectional converters
  (Anthropic ↔ Chat, Chat ↔ Responses, Anthropic ↔ Responses)
- Implement /v1/messages endpoint for OpenAI gateway with streaming support
- Add model mapping UI for OpenAI OAuth accounts (whitelist + mapping modes)
- Support prompt caching fields and codex OAuth transforms
- Fix tool call ID conversion for Responses API (fc_ prefix)
- Ensure function_call_output has non-empty output field

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 14:29:22 +08:00
Wesley Liddick
cac230206d Merge pull request #806 from touwaeriol/fix/openai-passthrough-model-check
fix(openai): passthrough accounts bypass model mapping check
2026-03-06 14:09:07 +08:00
erio
79ae15d5e8 fix: OpenAI passthrough accounts bypass model mapping check
透传模式账号仅替换认证,应允许所有模型通过。之前调度阶段的
isModelSupportedByAccount 不感知透传模式,导致 model_mapping
中未配置的新模型(如 gpt-5.4)被拒绝返回 503。
2026-03-06 14:01:47 +08:00
185 changed files with 13761 additions and 1899 deletions

View File

@@ -19,7 +19,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
go version | grep -q 'go1.25.7'
go version | grep -q 'go1.26.1'
- name: Unit tests
working-directory: backend
run: make test-unit
@@ -38,10 +38,10 @@ jobs:
cache: true
- name: Verify Go version
run: |
go version | grep -q 'go1.25.7'
go version | grep -q 'go1.26.1'
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.7
version: v2.9
args: --timeout=30m
working-directory: backend

View File

@@ -115,7 +115,7 @@ jobs:
- name: Verify Go version
run: |
go version | grep -q 'go1.25.7'
go version | grep -q 'go1.26.1'
# Docker setup for GoReleaser
- name: Set up QEMU

View File

@@ -23,7 +23,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.25.7'
go version | grep -q 'go1.26.1'
- name: Run govulncheck
working-directory: backend
run: |

View File

@@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.25.7-alpine
ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG ALPINE_IMAGE=alpine:3.21
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn

View File

@@ -93,20 +93,13 @@ linters:
check-escaping-errors: true
staticcheck:
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
dot-import-whitelist:
- fmt
# https://staticcheck.dev/docs/configuration/options/#initialisms
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
# Default: ["200", "400", "404", "500"]
http-status-code-whitelist: [ "200", "400", "404", "500" ]
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
# Temporarily disable style checks to allow CI to pass
# "all" enables every SA/ST/S/QF check; only list the ones to disable.
checks:
- all
- -ST1000 # Package comment format
@@ -114,489 +107,19 @@ linters:
- -ST1020 # Comment on exported method format
- -ST1021 # Comment on exported type format
- -ST1022 # Comment on exported variable format
# Invalid regular expression.
# https://staticcheck.dev/docs/checks/#SA1000
- SA1000
# Invalid template.
# https://staticcheck.dev/docs/checks/#SA1001
- SA1001
# Invalid format in 'time.Parse'.
# https://staticcheck.dev/docs/checks/#SA1002
- SA1002
# Unsupported argument to functions in 'encoding/binary'.
# https://staticcheck.dev/docs/checks/#SA1003
- SA1003
# Suspiciously small untyped constant in 'time.Sleep'.
# https://staticcheck.dev/docs/checks/#SA1004
- SA1004
# Invalid first argument to 'exec.Command'.
# https://staticcheck.dev/docs/checks/#SA1005
- SA1005
# 'Printf' with dynamic first argument and no further arguments.
# https://staticcheck.dev/docs/checks/#SA1006
- SA1006
# Invalid URL in 'net/url.Parse'.
# https://staticcheck.dev/docs/checks/#SA1007
- SA1007
# Non-canonical key in 'http.Header' map.
# https://staticcheck.dev/docs/checks/#SA1008
- SA1008
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
# https://staticcheck.dev/docs/checks/#SA1010
- SA1010
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
# https://staticcheck.dev/docs/checks/#SA1011
- SA1011
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
# https://staticcheck.dev/docs/checks/#SA1012
- SA1012
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
# https://staticcheck.dev/docs/checks/#SA1013
- SA1013
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
# https://staticcheck.dev/docs/checks/#SA1014
- SA1014
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
# https://staticcheck.dev/docs/checks/#SA1015
- SA1015
# Trapping a signal that cannot be trapped.
# https://staticcheck.dev/docs/checks/#SA1016
- SA1016
# Channels used with 'os/signal.Notify' should be buffered.
# https://staticcheck.dev/docs/checks/#SA1017
- SA1017
# 'strings.Replace' called with 'n == 0', which does nothing.
# https://staticcheck.dev/docs/checks/#SA1018
- SA1018
# Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019
- SA1019
# Using an invalid host:port pair with a 'net.Listen'-related function.
# https://staticcheck.dev/docs/checks/#SA1020
- SA1020
# Using 'bytes.Equal' to compare two 'net.IP'.
# https://staticcheck.dev/docs/checks/#SA1021
- SA1021
# Modifying the buffer in an 'io.Writer' implementation.
# https://staticcheck.dev/docs/checks/#SA1023
- SA1023
# A string cutset contains duplicate characters.
# https://staticcheck.dev/docs/checks/#SA1024
- SA1024
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
# https://staticcheck.dev/docs/checks/#SA1025
- SA1025
# Cannot marshal channels or functions.
# https://staticcheck.dev/docs/checks/#SA1026
- SA1026
# Atomic access to 64-bit variable must be 64-bit aligned.
# https://staticcheck.dev/docs/checks/#SA1027
- SA1027
# 'sort.Slice' can only be used on slices.
# https://staticcheck.dev/docs/checks/#SA1028
- SA1028
# Inappropriate key in call to 'context.WithValue'.
# https://staticcheck.dev/docs/checks/#SA1029
- SA1029
# Invalid argument in call to a 'strconv' function.
# https://staticcheck.dev/docs/checks/#SA1030
- SA1030
# Overlapping byte slices passed to an encoder.
# https://staticcheck.dev/docs/checks/#SA1031
- SA1031
# Wrong order of arguments to 'errors.Is'.
# https://staticcheck.dev/docs/checks/#SA1032
- SA1032
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
# https://staticcheck.dev/docs/checks/#SA2000
- SA2000
# Empty critical section, did you mean to defer the unlock?.
# https://staticcheck.dev/docs/checks/#SA2001
- SA2001
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
# https://staticcheck.dev/docs/checks/#SA2002
- SA2002
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
# https://staticcheck.dev/docs/checks/#SA2003
- SA2003
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
# https://staticcheck.dev/docs/checks/#SA3000
- SA3000
# Assigning to 'b.N' in benchmarks distorts the results.
# https://staticcheck.dev/docs/checks/#SA3001
- SA3001
# Binary operator has identical expressions on both sides.
# https://staticcheck.dev/docs/checks/#SA4000
- SA4000
# '&*x' gets simplified to 'x', it does not copy 'x'.
# https://staticcheck.dev/docs/checks/#SA4001
- SA4001
# Comparing unsigned values against negative values is pointless.
# https://staticcheck.dev/docs/checks/#SA4003
- SA4003
# The loop exits unconditionally after one iteration.
# https://staticcheck.dev/docs/checks/#SA4004
- SA4004
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
# https://staticcheck.dev/docs/checks/#SA4005
- SA4005
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
# https://staticcheck.dev/docs/checks/#SA4006
- SA4006
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
# https://staticcheck.dev/docs/checks/#SA4008
- SA4008
# A function argument is overwritten before its first use.
# https://staticcheck.dev/docs/checks/#SA4009
- SA4009
# The result of 'append' will never be observed anywhere.
# https://staticcheck.dev/docs/checks/#SA4010
- SA4010
# Break statement with no effect. Did you mean to break out of an outer loop?.
# https://staticcheck.dev/docs/checks/#SA4011
- SA4011
# Comparing a value against NaN even though no value is equal to NaN.
# https://staticcheck.dev/docs/checks/#SA4012
- SA4012
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
# https://staticcheck.dev/docs/checks/#SA4013
- SA4013
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
# https://staticcheck.dev/docs/checks/#SA4014
- SA4014
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
# https://staticcheck.dev/docs/checks/#SA4015
- SA4015
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
# https://staticcheck.dev/docs/checks/#SA4016
- SA4016
# Discarding the return values of a function without side effects, making the call pointless.
# https://staticcheck.dev/docs/checks/#SA4017
- SA4017
# Self-assignment of variables.
# https://staticcheck.dev/docs/checks/#SA4018
- SA4018
# Multiple, identical build constraints in the same file.
# https://staticcheck.dev/docs/checks/#SA4019
- SA4019
# Unreachable case clause in a type switch.
# https://staticcheck.dev/docs/checks/#SA4020
- SA4020
# "x = append(y)" is equivalent to "x = y".
# https://staticcheck.dev/docs/checks/#SA4021
- SA4021
# Comparing the address of a variable against nil.
# https://staticcheck.dev/docs/checks/#SA4022
- SA4022
# Impossible comparison of interface value with untyped nil.
# https://staticcheck.dev/docs/checks/#SA4023
- SA4023
# Checking for impossible return value from a builtin function.
# https://staticcheck.dev/docs/checks/#SA4024
- SA4024
# Integer division of literals that results in zero.
# https://staticcheck.dev/docs/checks/#SA4025
- SA4025
# Go constants cannot express negative zero.
# https://staticcheck.dev/docs/checks/#SA4026
- SA4026
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
# https://staticcheck.dev/docs/checks/#SA4027
- SA4027
# 'x % 1' is always zero.
# https://staticcheck.dev/docs/checks/#SA4028
- SA4028
# Ineffective attempt at sorting slice.
# https://staticcheck.dev/docs/checks/#SA4029
- SA4029
# Ineffective attempt at generating random number.
# https://staticcheck.dev/docs/checks/#SA4030
- SA4030
# Checking never-nil value against nil.
# https://staticcheck.dev/docs/checks/#SA4031
- SA4031
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
# https://staticcheck.dev/docs/checks/#SA4032
- SA4032
# Assignment to nil map.
# https://staticcheck.dev/docs/checks/#SA5000
- SA5000
# Deferring 'Close' before checking for a possible error.
# https://staticcheck.dev/docs/checks/#SA5001
- SA5001
# The empty for loop ("for {}") spins and can block the scheduler.
# https://staticcheck.dev/docs/checks/#SA5002
- SA5002
# Defers in infinite loops will never execute.
# https://staticcheck.dev/docs/checks/#SA5003
- SA5003
# "for { select { ..." with an empty default branch spins.
# https://staticcheck.dev/docs/checks/#SA5004
- SA5004
# The finalizer references the finalized object, preventing garbage collection.
# https://staticcheck.dev/docs/checks/#SA5005
- SA5005
# Infinite recursive call.
# https://staticcheck.dev/docs/checks/#SA5007
- SA5007
# Invalid struct tag.
# https://staticcheck.dev/docs/checks/#SA5008
- SA5008
# Invalid Printf call.
# https://staticcheck.dev/docs/checks/#SA5009
- SA5009
# Impossible type assertion.
# https://staticcheck.dev/docs/checks/#SA5010
- SA5010
# Possible nil pointer dereference.
# https://staticcheck.dev/docs/checks/#SA5011
- SA5011
# Passing odd-sized slice to function expecting even size.
# https://staticcheck.dev/docs/checks/#SA5012
- SA5012
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
# https://staticcheck.dev/docs/checks/#SA6000
- SA6000
# Missing an optimization opportunity when indexing maps by byte slices.
# https://staticcheck.dev/docs/checks/#SA6001
- SA6001
# Storing non-pointer values in 'sync.Pool' allocates memory.
# https://staticcheck.dev/docs/checks/#SA6002
- SA6002
# Converting a string to a slice of runes before ranging over it.
# https://staticcheck.dev/docs/checks/#SA6003
- SA6003
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
# https://staticcheck.dev/docs/checks/#SA6005
- SA6005
# Using io.WriteString to write '[]byte'.
# https://staticcheck.dev/docs/checks/#SA6006
- SA6006
# Defers in range loops may not run when you expect them to.
# https://staticcheck.dev/docs/checks/#SA9001
- SA9001
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
# https://staticcheck.dev/docs/checks/#SA9002
- SA9002
# Empty body in an if or else branch.
# https://staticcheck.dev/docs/checks/#SA9003
- SA9003
# Only the first constant has an explicit type.
# https://staticcheck.dev/docs/checks/#SA9004
- SA9004
# Trying to marshal a struct with no public fields nor custom marshaling.
# https://staticcheck.dev/docs/checks/#SA9005
- SA9005
# Dubious bit shifting of a fixed size integer value.
# https://staticcheck.dev/docs/checks/#SA9006
- SA9006
# Deleting a directory that shouldn't be deleted.
# https://staticcheck.dev/docs/checks/#SA9007
- SA9007
# 'else' branch of a type assertion is probably not reading the right value.
# https://staticcheck.dev/docs/checks/#SA9008
- SA9008
# Ineffectual Go compiler directive.
# https://staticcheck.dev/docs/checks/#SA9009
- SA9009
# NOTE: ST1000, ST1001, ST1003, ST1020, ST1021, ST1022 are disabled above
# Incorrectly formatted error string.
# https://staticcheck.dev/docs/checks/#ST1005
- ST1005
# Poorly chosen receiver name.
# https://staticcheck.dev/docs/checks/#ST1006
- ST1006
# A function's error value should be its last return value.
# https://staticcheck.dev/docs/checks/#ST1008
- ST1008
# Poorly chosen name for variable of type 'time.Duration'.
# https://staticcheck.dev/docs/checks/#ST1011
- ST1011
# Poorly chosen name for error variable.
# https://staticcheck.dev/docs/checks/#ST1012
- ST1012
# Should use constants for HTTP error codes, not magic numbers.
# https://staticcheck.dev/docs/checks/#ST1013
- ST1013
# A switch's default case should be the first or last case.
# https://staticcheck.dev/docs/checks/#ST1015
- ST1015
# Use consistent method receiver names.
# https://staticcheck.dev/docs/checks/#ST1016
- ST1016
# Don't use Yoda conditions.
# https://staticcheck.dev/docs/checks/#ST1017
- ST1017
# Avoid zero-width and control characters in string literals.
# https://staticcheck.dev/docs/checks/#ST1018
- ST1018
# Importing the same package multiple times.
# https://staticcheck.dev/docs/checks/#ST1019
- ST1019
# NOTE: ST1020, ST1021, ST1022 removed (disabled above)
# Redundant type in variable declaration.
# https://staticcheck.dev/docs/checks/#ST1023
- ST1023
# Use plain channel send or receive instead of single-case select.
# https://staticcheck.dev/docs/checks/#S1000
- S1000
# Replace for loop with call to copy.
# https://staticcheck.dev/docs/checks/#S1001
- S1001
# Omit comparison with boolean constant.
# https://staticcheck.dev/docs/checks/#S1002
- S1002
# Replace call to 'strings.Index' with 'strings.Contains'.
# https://staticcheck.dev/docs/checks/#S1003
- S1003
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
# https://staticcheck.dev/docs/checks/#S1004
- S1004
# Drop unnecessary use of the blank identifier.
# https://staticcheck.dev/docs/checks/#S1005
- S1005
# Use "for { ... }" for infinite loops.
# https://staticcheck.dev/docs/checks/#S1006
- S1006
# Simplify regular expression by using raw string literal.
# https://staticcheck.dev/docs/checks/#S1007
- S1007
# Simplify returning boolean expression.
# https://staticcheck.dev/docs/checks/#S1008
- S1008
# Omit redundant nil check on slices, maps, and channels.
# https://staticcheck.dev/docs/checks/#S1009
- S1009
# Omit default slice index.
# https://staticcheck.dev/docs/checks/#S1010
- S1010
# Use a single 'append' to concatenate two slices.
# https://staticcheck.dev/docs/checks/#S1011
- S1011
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
# https://staticcheck.dev/docs/checks/#S1012
- S1012
# Use a type conversion instead of manually copying struct fields.
# https://staticcheck.dev/docs/checks/#S1016
- S1016
# Replace manual trimming with 'strings.TrimPrefix'.
# https://staticcheck.dev/docs/checks/#S1017
- S1017
# Use "copy" for sliding elements.
# https://staticcheck.dev/docs/checks/#S1018
- S1018
# Simplify "make" call by omitting redundant arguments.
# https://staticcheck.dev/docs/checks/#S1019
- S1019
# Omit redundant nil check in type assertion.
# https://staticcheck.dev/docs/checks/#S1020
- S1020
# Merge variable declaration and assignment.
# https://staticcheck.dev/docs/checks/#S1021
- S1021
# Omit redundant control flow.
# https://staticcheck.dev/docs/checks/#S1023
- S1023
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
# https://staticcheck.dev/docs/checks/#S1024
- S1024
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
# https://staticcheck.dev/docs/checks/#S1025
- S1025
# Simplify error construction with 'fmt.Errorf'.
# https://staticcheck.dev/docs/checks/#S1028
- S1028
# Range over the string directly.
# https://staticcheck.dev/docs/checks/#S1029
- S1029
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
# https://staticcheck.dev/docs/checks/#S1030
- S1030
# Omit redundant nil check around loop.
# https://staticcheck.dev/docs/checks/#S1031
- S1031
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
# https://staticcheck.dev/docs/checks/#S1032
- S1032
# Unnecessary guard around call to "delete".
# https://staticcheck.dev/docs/checks/#S1033
- S1033
# Use result of type assertion to simplify cases.
# https://staticcheck.dev/docs/checks/#S1034
- S1034
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
# https://staticcheck.dev/docs/checks/#S1035
- S1035
# Unnecessary guard around map access.
# https://staticcheck.dev/docs/checks/#S1036
- S1036
# Elaborate way of sleeping.
# https://staticcheck.dev/docs/checks/#S1037
- S1037
# Unnecessarily complex way of printing formatted string.
# https://staticcheck.dev/docs/checks/#S1038
- S1038
# Unnecessary use of 'fmt.Sprint'.
# https://staticcheck.dev/docs/checks/#S1039
- S1039
# Type assertion to current type.
# https://staticcheck.dev/docs/checks/#S1040
- S1040
# Apply De Morgan's law.
# https://staticcheck.dev/docs/checks/#QF1001
- QF1001
# Convert untagged switch to tagged switch.
# https://staticcheck.dev/docs/checks/#QF1002
- QF1002
# Convert if/else-if chain to tagged switch.
# https://staticcheck.dev/docs/checks/#QF1003
- QF1003
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
# https://staticcheck.dev/docs/checks/#QF1004
- QF1004
# Expand call to 'math.Pow'.
# https://staticcheck.dev/docs/checks/#QF1005
- QF1005
# Lift 'if'+'break' into loop condition.
# https://staticcheck.dev/docs/checks/#QF1006
- QF1006
# Merge conditional assignment into variable declaration.
# https://staticcheck.dev/docs/checks/#QF1007
- QF1007
# Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008
- QF1008
# Use 'time.Time.Equal' instead of '==' operator.
# https://staticcheck.dev/docs/checks/#QF1009
- QF1009
# Convert slice of bytes to string when printing it.
# https://staticcheck.dev/docs/checks/#QF1010
- QF1010
# Omit redundant type from variable declaration.
# https://staticcheck.dev/docs/checks/#QF1011
- QF1011
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
# https://staticcheck.dev/docs/checks/#QF1012
- QF1012
unused:
# Mark all struct fields that have been written to as used.
# Default: true
field-writes-are-uses: false
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
field-writes-are-uses: true
# Default: false
post-statements-are-reads: true
# Mark all exported fields as used.
# default: true
exported-fields-are-used: false
# Mark all function parameters as used.
# default: true
parameters-are-used: true
# Mark all local variables as used.
# default: true
local-variables-are-used: false
# Mark all identifiers inside generated files as used.
# Default: true
generated-is-used: false
exported-fields-are-used: true
# Default: true
parameters-are-used: true
# Default: true
local-variables-are-used: false
# Default: true — must be true, ent generates 130K+ lines of code
generated-is-used: true
formatters:
enable:

View File

@@ -162,9 +162,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@@ -229,7 +229,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
application := &Application{
Server: httpServer,

View File

@@ -25,6 +25,8 @@ type Announcement struct {
Content string `json:"content,omitempty"`
// 状态: draft, active, archived
Status string `json:"status,omitempty"`
// 通知模式: silent(仅铃铛), popup(弹窗提醒)
NotifyMode string `json:"notify_mode,omitempty"`
// 展示条件JSON 规则)
Targeting domain.AnnouncementTargeting `json:"targeting,omitempty"`
// 开始展示时间(为空表示立即生效)
@@ -72,7 +74,7 @@ func (*Announcement) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte)
case announcement.FieldID, announcement.FieldCreatedBy, announcement.FieldUpdatedBy:
values[i] = new(sql.NullInt64)
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus:
case announcement.FieldTitle, announcement.FieldContent, announcement.FieldStatus, announcement.FieldNotifyMode:
values[i] = new(sql.NullString)
case announcement.FieldStartsAt, announcement.FieldEndsAt, announcement.FieldCreatedAt, announcement.FieldUpdatedAt:
values[i] = new(sql.NullTime)
@@ -115,6 +117,12 @@ func (_m *Announcement) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Status = value.String
}
case announcement.FieldNotifyMode:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field notify_mode", values[i])
} else if value.Valid {
_m.NotifyMode = value.String
}
case announcement.FieldTargeting:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field targeting", values[i])
@@ -213,6 +221,9 @@ func (_m *Announcement) String() string {
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("notify_mode=")
builder.WriteString(_m.NotifyMode)
builder.WriteString(", ")
builder.WriteString("targeting=")
builder.WriteString(fmt.Sprintf("%v", _m.Targeting))
builder.WriteString(", ")

View File

@@ -20,6 +20,8 @@ const (
FieldContent = "content"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldNotifyMode holds the string denoting the notify_mode field in the database.
FieldNotifyMode = "notify_mode"
// FieldTargeting holds the string denoting the targeting field in the database.
FieldTargeting = "targeting"
// FieldStartsAt holds the string denoting the starts_at field in the database.
@@ -53,6 +55,7 @@ var Columns = []string{
FieldTitle,
FieldContent,
FieldStatus,
FieldNotifyMode,
FieldTargeting,
FieldStartsAt,
FieldEndsAt,
@@ -81,6 +84,10 @@ var (
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
StatusValidator func(string) error
// DefaultNotifyMode holds the default value on creation for the "notify_mode" field.
DefaultNotifyMode string
// NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
NotifyModeValidator func(string) error
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
@@ -112,6 +119,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByNotifyMode orders the results by the notify_mode field.
func ByNotifyMode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotifyMode, opts...).ToFunc()
}
// ByStartsAt orders the results by the starts_at field.
func ByStartsAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartsAt, opts...).ToFunc()

View File

@@ -70,6 +70,11 @@ func Status(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStatus, v))
}
// NotifyMode applies equality check predicate on the "notify_mode" field. It's identical to NotifyModeEQ.
func NotifyMode(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
}
// StartsAt applies equality check predicate on the "starts_at" field. It's identical to StartsAtEQ.
func StartsAt(v time.Time) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldStartsAt, v))
@@ -295,6 +300,71 @@ func StatusContainsFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContainsFold(FieldStatus, v))
}
// NotifyModeEQ applies the EQ predicate on the "notify_mode" field.
func NotifyModeEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEQ(FieldNotifyMode, v))
}
// NotifyModeNEQ applies the NEQ predicate on the "notify_mode" field.
func NotifyModeNEQ(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldNEQ(FieldNotifyMode, v))
}
// NotifyModeIn applies the In predicate on the "notify_mode" field.
func NotifyModeIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldIn(FieldNotifyMode, vs...))
}
// NotifyModeNotIn applies the NotIn predicate on the "notify_mode" field.
func NotifyModeNotIn(vs ...string) predicate.Announcement {
return predicate.Announcement(sql.FieldNotIn(FieldNotifyMode, vs...))
}
// NotifyModeGT applies the GT predicate on the "notify_mode" field.
func NotifyModeGT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGT(FieldNotifyMode, v))
}
// NotifyModeGTE applies the GTE predicate on the "notify_mode" field.
func NotifyModeGTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldGTE(FieldNotifyMode, v))
}
// NotifyModeLT applies the LT predicate on the "notify_mode" field.
func NotifyModeLT(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLT(FieldNotifyMode, v))
}
// NotifyModeLTE applies the LTE predicate on the "notify_mode" field.
func NotifyModeLTE(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldLTE(FieldNotifyMode, v))
}
// NotifyModeContains applies the Contains predicate on the "notify_mode" field.
func NotifyModeContains(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContains(FieldNotifyMode, v))
}
// NotifyModeHasPrefix applies the HasPrefix predicate on the "notify_mode" field.
func NotifyModeHasPrefix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasPrefix(FieldNotifyMode, v))
}
// NotifyModeHasSuffix applies the HasSuffix predicate on the "notify_mode" field.
func NotifyModeHasSuffix(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldHasSuffix(FieldNotifyMode, v))
}
// NotifyModeEqualFold applies the EqualFold predicate on the "notify_mode" field.
func NotifyModeEqualFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldEqualFold(FieldNotifyMode, v))
}
// NotifyModeContainsFold applies the ContainsFold predicate on the "notify_mode" field.
func NotifyModeContainsFold(v string) predicate.Announcement {
return predicate.Announcement(sql.FieldContainsFold(FieldNotifyMode, v))
}
// TargetingIsNil applies the IsNil predicate on the "targeting" field.
func TargetingIsNil() predicate.Announcement {
return predicate.Announcement(sql.FieldIsNull(FieldTargeting))

View File

@@ -50,6 +50,20 @@ func (_c *AnnouncementCreate) SetNillableStatus(v *string) *AnnouncementCreate {
return _c
}
// SetNotifyMode sets the "notify_mode" field.
func (_c *AnnouncementCreate) SetNotifyMode(v string) *AnnouncementCreate {
_c.mutation.SetNotifyMode(v)
return _c
}
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
func (_c *AnnouncementCreate) SetNillableNotifyMode(v *string) *AnnouncementCreate {
if v != nil {
_c.SetNotifyMode(*v)
}
return _c
}
// SetTargeting sets the "targeting" field.
func (_c *AnnouncementCreate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementCreate {
_c.mutation.SetTargeting(v)
@@ -202,6 +216,10 @@ func (_c *AnnouncementCreate) defaults() {
v := announcement.DefaultStatus
_c.mutation.SetStatus(v)
}
if _, ok := _c.mutation.NotifyMode(); !ok {
v := announcement.DefaultNotifyMode
_c.mutation.SetNotifyMode(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok {
v := announcement.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
@@ -238,6 +256,14 @@ func (_c *AnnouncementCreate) check() error {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
}
}
if _, ok := _c.mutation.NotifyMode(); !ok {
return &ValidationError{Name: "notify_mode", err: errors.New(`ent: missing required field "Announcement.notify_mode"`)}
}
if v, ok := _c.mutation.NotifyMode(); ok {
if err := announcement.NotifyModeValidator(v); err != nil {
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
}
}
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Announcement.created_at"`)}
}
@@ -283,6 +309,10 @@ func (_c *AnnouncementCreate) createSpec() (*Announcement, *sqlgraph.CreateSpec)
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
_node.Status = value
}
if value, ok := _c.mutation.NotifyMode(); ok {
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
_node.NotifyMode = value
}
if value, ok := _c.mutation.Targeting(); ok {
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
_node.Targeting = value
@@ -415,6 +445,18 @@ func (u *AnnouncementUpsert) UpdateStatus() *AnnouncementUpsert {
return u
}
// SetNotifyMode sets the "notify_mode" field.
func (u *AnnouncementUpsert) SetNotifyMode(v string) *AnnouncementUpsert {
u.Set(announcement.FieldNotifyMode, v)
return u
}
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
func (u *AnnouncementUpsert) UpdateNotifyMode() *AnnouncementUpsert {
u.SetExcluded(announcement.FieldNotifyMode)
return u
}
// SetTargeting sets the "targeting" field.
func (u *AnnouncementUpsert) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsert {
u.Set(announcement.FieldTargeting, v)
@@ -616,6 +658,20 @@ func (u *AnnouncementUpsertOne) UpdateStatus() *AnnouncementUpsertOne {
})
}
// SetNotifyMode sets the "notify_mode" field.
func (u *AnnouncementUpsertOne) SetNotifyMode(v string) *AnnouncementUpsertOne {
return u.Update(func(s *AnnouncementUpsert) {
s.SetNotifyMode(v)
})
}
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
func (u *AnnouncementUpsertOne) UpdateNotifyMode() *AnnouncementUpsertOne {
return u.Update(func(s *AnnouncementUpsert) {
s.UpdateNotifyMode()
})
}
// SetTargeting sets the "targeting" field.
func (u *AnnouncementUpsertOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertOne {
return u.Update(func(s *AnnouncementUpsert) {
@@ -1002,6 +1058,20 @@ func (u *AnnouncementUpsertBulk) UpdateStatus() *AnnouncementUpsertBulk {
})
}
// SetNotifyMode sets the "notify_mode" field.
func (u *AnnouncementUpsertBulk) SetNotifyMode(v string) *AnnouncementUpsertBulk {
return u.Update(func(s *AnnouncementUpsert) {
s.SetNotifyMode(v)
})
}
// UpdateNotifyMode sets the "notify_mode" field to the value that was provided on create.
func (u *AnnouncementUpsertBulk) UpdateNotifyMode() *AnnouncementUpsertBulk {
return u.Update(func(s *AnnouncementUpsert) {
s.UpdateNotifyMode()
})
}
// SetTargeting sets the "targeting" field.
func (u *AnnouncementUpsertBulk) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpsertBulk {
return u.Update(func(s *AnnouncementUpsert) {

View File

@@ -72,6 +72,20 @@ func (_u *AnnouncementUpdate) SetNillableStatus(v *string) *AnnouncementUpdate {
return _u
}
// SetNotifyMode sets the "notify_mode" field.
func (_u *AnnouncementUpdate) SetNotifyMode(v string) *AnnouncementUpdate {
_u.mutation.SetNotifyMode(v)
return _u
}
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
func (_u *AnnouncementUpdate) SetNillableNotifyMode(v *string) *AnnouncementUpdate {
if v != nil {
_u.SetNotifyMode(*v)
}
return _u
}
// SetTargeting sets the "targeting" field.
func (_u *AnnouncementUpdate) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdate {
_u.mutation.SetTargeting(v)
@@ -286,6 +300,11 @@ func (_u *AnnouncementUpdate) check() error {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
}
}
if v, ok := _u.mutation.NotifyMode(); ok {
if err := announcement.NotifyModeValidator(v); err != nil {
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
}
}
return nil
}
@@ -310,6 +329,9 @@ func (_u *AnnouncementUpdate) sqlSave(ctx context.Context) (_node int, err error
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.NotifyMode(); ok {
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
}
if value, ok := _u.mutation.Targeting(); ok {
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
}
@@ -456,6 +478,20 @@ func (_u *AnnouncementUpdateOne) SetNillableStatus(v *string) *AnnouncementUpdat
return _u
}
// SetNotifyMode sets the "notify_mode" field.
func (_u *AnnouncementUpdateOne) SetNotifyMode(v string) *AnnouncementUpdateOne {
_u.mutation.SetNotifyMode(v)
return _u
}
// SetNillableNotifyMode sets the "notify_mode" field if the given value is not nil.
func (_u *AnnouncementUpdateOne) SetNillableNotifyMode(v *string) *AnnouncementUpdateOne {
if v != nil {
_u.SetNotifyMode(*v)
}
return _u
}
// SetTargeting sets the "targeting" field.
func (_u *AnnouncementUpdateOne) SetTargeting(v domain.AnnouncementTargeting) *AnnouncementUpdateOne {
_u.mutation.SetTargeting(v)
@@ -683,6 +719,11 @@ func (_u *AnnouncementUpdateOne) check() error {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Announcement.status": %w`, err)}
}
}
if v, ok := _u.mutation.NotifyMode(); ok {
if err := announcement.NotifyModeValidator(v); err != nil {
return &ValidationError{Name: "notify_mode", err: fmt.Errorf(`ent: validator failed for field "Announcement.notify_mode": %w`, err)}
}
}
return nil
}
@@ -724,6 +765,9 @@ func (_u *AnnouncementUpdateOne) sqlSave(ctx context.Context) (_node *Announceme
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(announcement.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.NotifyMode(); ok {
_spec.SetField(announcement.FieldNotifyMode, field.TypeString, value)
}
if value, ok := _u.mutation.Targeting(); ok {
_spec.SetField(announcement.FieldTargeting, field.TypeJSON, value)
}

View File

@@ -78,6 +78,10 @@ type Group struct {
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// 分组显示排序,数值越小越靠前
SortOrder int `json:"sort_order,omitempty"`
// 是否允许 /v1/messages 调度到此 OpenAI 分组
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
// 默认映射模型 ID当账号级映射找不到时使用此值
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -186,13 +190,13 @@ func (*Group) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
values[i] = new(sql.NullString)
case group.FieldCreatedAt, group.FieldUpdatedAt, group.FieldDeletedAt:
values[i] = new(sql.NullTime)
@@ -415,6 +419,18 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.SortOrder = int(value.Int64)
}
case group.FieldAllowMessagesDispatch:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field allow_messages_dispatch", values[i])
} else if value.Valid {
_m.AllowMessagesDispatch = value.Bool
}
case group.FieldDefaultMappedModel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i])
} else if value.Valid {
_m.DefaultMappedModel = value.String
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -608,6 +624,12 @@ func (_m *Group) String() string {
builder.WriteString(", ")
builder.WriteString("sort_order=")
builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
builder.WriteString(", ")
builder.WriteString("allow_messages_dispatch=")
builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch))
builder.WriteString(", ")
builder.WriteString("default_mapped_model=")
builder.WriteString(_m.DefaultMappedModel)
builder.WriteByte(')')
return builder.String()
}

View File

@@ -75,6 +75,10 @@ const (
FieldSupportedModelScopes = "supported_model_scopes"
// FieldSortOrder holds the string denoting the sort_order field in the database.
FieldSortOrder = "sort_order"
// FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database.
FieldAllowMessagesDispatch = "allow_messages_dispatch"
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
FieldDefaultMappedModel = "default_mapped_model"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -180,6 +184,8 @@ var Columns = []string{
FieldMcpXMLInject,
FieldSupportedModelScopes,
FieldSortOrder,
FieldAllowMessagesDispatch,
FieldDefaultMappedModel,
}
var (
@@ -247,6 +253,12 @@ var (
DefaultSupportedModelScopes []string
// DefaultSortOrder holds the default value on creation for the "sort_order" field.
DefaultSortOrder int
// DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field.
DefaultAllowMessagesDispatch bool
// DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field.
DefaultDefaultMappedModel string
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
DefaultMappedModelValidator func(string) error
)
// OrderOption defines the ordering options for the Group queries.
@@ -397,6 +409,16 @@ func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
}
// ByAllowMessagesDispatch orders the results by the allow_messages_dispatch field.
func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc()
}
// ByDefaultMappedModel orders the results by the default_mapped_model field.
func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -195,6 +195,16 @@ func SortOrder(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
}
// AllowMessagesDispatch applies equality check predicate on the "allow_messages_dispatch" field. It's identical to AllowMessagesDispatchEQ.
func AllowMessagesDispatch(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
}
// DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ.
func DefaultMappedModel(v string) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1470,6 +1480,81 @@ func SortOrderLTE(v int) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSortOrder, v))
}
// AllowMessagesDispatchEQ applies the EQ predicate on the "allow_messages_dispatch" field.
func AllowMessagesDispatchEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v))
}
// AllowMessagesDispatchNEQ applies the NEQ predicate on the "allow_messages_dispatch" field.
func AllowMessagesDispatchNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v))
}
// DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field.
func DefaultMappedModelEQ(v string) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
}
// DefaultMappedModelNEQ applies the NEQ predicate on the "default_mapped_model" field.
func DefaultMappedModelNEQ(v string) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldDefaultMappedModel, v))
}
// DefaultMappedModelIn applies the In predicate on the "default_mapped_model" field.
func DefaultMappedModelIn(vs ...string) predicate.Group {
return predicate.Group(sql.FieldIn(FieldDefaultMappedModel, vs...))
}
// DefaultMappedModelNotIn applies the NotIn predicate on the "default_mapped_model" field.
func DefaultMappedModelNotIn(vs ...string) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldDefaultMappedModel, vs...))
}
// DefaultMappedModelGT applies the GT predicate on the "default_mapped_model" field.
func DefaultMappedModelGT(v string) predicate.Group {
return predicate.Group(sql.FieldGT(FieldDefaultMappedModel, v))
}
// DefaultMappedModelGTE applies the GTE predicate on the "default_mapped_model" field.
func DefaultMappedModelGTE(v string) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldDefaultMappedModel, v))
}
// DefaultMappedModelLT applies the LT predicate on the "default_mapped_model" field.
func DefaultMappedModelLT(v string) predicate.Group {
return predicate.Group(sql.FieldLT(FieldDefaultMappedModel, v))
}
// DefaultMappedModelLTE applies the LTE predicate on the "default_mapped_model" field.
func DefaultMappedModelLTE(v string) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldDefaultMappedModel, v))
}
// DefaultMappedModelContains applies the Contains predicate on the "default_mapped_model" field.
func DefaultMappedModelContains(v string) predicate.Group {
return predicate.Group(sql.FieldContains(FieldDefaultMappedModel, v))
}
// DefaultMappedModelHasPrefix applies the HasPrefix predicate on the "default_mapped_model" field.
func DefaultMappedModelHasPrefix(v string) predicate.Group {
return predicate.Group(sql.FieldHasPrefix(FieldDefaultMappedModel, v))
}
// DefaultMappedModelHasSuffix applies the HasSuffix predicate on the "default_mapped_model" field.
func DefaultMappedModelHasSuffix(v string) predicate.Group {
return predicate.Group(sql.FieldHasSuffix(FieldDefaultMappedModel, v))
}
// DefaultMappedModelEqualFold applies the EqualFold predicate on the "default_mapped_model" field.
func DefaultMappedModelEqualFold(v string) predicate.Group {
return predicate.Group(sql.FieldEqualFold(FieldDefaultMappedModel, v))
}
// DefaultMappedModelContainsFold applies the ContainsFold predicate on the "default_mapped_model" field.
func DefaultMappedModelContainsFold(v string) predicate.Group {
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -424,6 +424,34 @@ func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate {
return _c
}
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
func (_c *GroupCreate) SetAllowMessagesDispatch(v bool) *GroupCreate {
_c.mutation.SetAllowMessagesDispatch(v)
return _c
}
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate {
if v != nil {
_c.SetAllowMessagesDispatch(*v)
}
return _c
}
// SetDefaultMappedModel sets the "default_mapped_model" field.
func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate {
_c.mutation.SetDefaultMappedModel(v)
return _c
}
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
if v != nil {
_c.SetDefaultMappedModel(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -613,6 +641,14 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultSortOrder
_c.mutation.SetSortOrder(v)
}
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
v := group.DefaultAllowMessagesDispatch
_c.mutation.SetAllowMessagesDispatch(v)
}
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
v := group.DefaultDefaultMappedModel
_c.mutation.SetDefaultMappedModel(v)
}
return nil
}
@@ -683,6 +719,17 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.SortOrder(); !ok {
return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)}
}
if _, ok := _c.mutation.AllowMessagesDispatch(); !ok {
return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)}
}
if _, ok := _c.mutation.DefaultMappedModel(); !ok {
return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)}
}
if v, ok := _c.mutation.DefaultMappedModel(); ok {
if err := group.DefaultMappedModelValidator(v); err != nil {
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
}
}
return nil
}
@@ -830,6 +877,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
_node.SortOrder = value
}
if value, ok := _c.mutation.AllowMessagesDispatch(); ok {
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
_node.AllowMessagesDispatch = value
}
if value, ok := _c.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
_node.DefaultMappedModel = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1520,6 +1575,30 @@ func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert {
return u
}
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
func (u *GroupUpsert) SetAllowMessagesDispatch(v bool) *GroupUpsert {
u.Set(group.FieldAllowMessagesDispatch, v)
return u
}
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert {
u.SetExcluded(group.FieldAllowMessagesDispatch)
return u
}
// SetDefaultMappedModel sets the "default_mapped_model" field.
func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert {
u.Set(group.FieldDefaultMappedModel, v)
return u
}
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
u.SetExcluded(group.FieldDefaultMappedModel)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2188,6 +2267,34 @@ func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne {
})
}
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
func (u *GroupUpsertOne) SetAllowMessagesDispatch(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetAllowMessagesDispatch(v)
})
}
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateAllowMessagesDispatch()
})
}
// SetDefaultMappedModel sets the "default_mapped_model" field.
func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetDefaultMappedModel(v)
})
}
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateDefaultMappedModel()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -3022,6 +3129,34 @@ func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk {
})
}
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
func (u *GroupUpsertBulk) SetAllowMessagesDispatch(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetAllowMessagesDispatch(v)
})
}
// UpdateAllowMessagesDispatch sets the "allow_messages_dispatch" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateAllowMessagesDispatch()
})
}
// SetDefaultMappedModel sets the "default_mapped_model" field.
func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetDefaultMappedModel(v)
})
}
// UpdateDefaultMappedModel sets the "default_mapped_model" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateDefaultMappedModel()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -625,6 +625,34 @@ func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate {
return _u
}
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
func (_u *GroupUpdate) SetAllowMessagesDispatch(v bool) *GroupUpdate {
_u.mutation.SetAllowMessagesDispatch(v)
return _u
}
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate {
if v != nil {
_u.SetAllowMessagesDispatch(*v)
}
return _u
}
// SetDefaultMappedModel sets the "default_mapped_model" field.
func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate {
_u.mutation.SetDefaultMappedModel(v)
return _u
}
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
if v != nil {
_u.SetDefaultMappedModel(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -910,6 +938,11 @@ func (_u *GroupUpdate) check() error {
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
}
}
if v, ok := _u.mutation.DefaultMappedModel(); ok {
if err := group.DefaultMappedModelValidator(v); err != nil {
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
}
}
return nil
}
@@ -1110,6 +1143,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedSortOrder(); ok {
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
}
if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -2014,6 +2053,34 @@ func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne {
return _u
}
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
func (_u *GroupUpdateOne) SetAllowMessagesDispatch(v bool) *GroupUpdateOne {
_u.mutation.SetAllowMessagesDispatch(v)
return _u
}
// SetNillableAllowMessagesDispatch sets the "allow_messages_dispatch" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetAllowMessagesDispatch(*v)
}
return _u
}
// SetDefaultMappedModel sets the "default_mapped_model" field.
func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne {
_u.mutation.SetDefaultMappedModel(v)
return _u
}
// SetNillableDefaultMappedModel sets the "default_mapped_model" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateOne {
if v != nil {
_u.SetDefaultMappedModel(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2312,6 +2379,11 @@ func (_u *GroupUpdateOne) check() error {
return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)}
}
}
if v, ok := _u.mutation.DefaultMappedModel(); ok {
if err := group.DefaultMappedModelValidator(v); err != nil {
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
}
}
return nil
}
@@ -2529,6 +2601,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.AddedSortOrder(); ok {
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.AllowMessagesDispatch(); ok {
_spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value)
}
if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -251,6 +251,7 @@ var (
{Name: "title", Type: field.TypeString, Size: 200},
{Name: "content", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "draft"},
{Name: "notify_mode", Type: field.TypeString, Size: 20, Default: "silent"},
{Name: "targeting", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "starts_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "ends_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
@@ -273,17 +274,17 @@ var (
{
Name: "announcement_created_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[9]},
Columns: []*schema.Column{AnnouncementsColumns[10]},
},
{
Name: "announcement_starts_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[5]},
Columns: []*schema.Column{AnnouncementsColumns[6]},
},
{
Name: "announcement_ends_at",
Unique: false,
Columns: []*schema.Column{AnnouncementsColumns[6]},
Columns: []*schema.Column{AnnouncementsColumns[7]},
},
},
}
@@ -407,6 +408,8 @@ var (
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "sort_order", Type: field.TypeInt, Default: 0},
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{

View File

@@ -5167,6 +5167,7 @@ type AnnouncementMutation struct {
title *string
content *string
status *string
notify_mode *string
targeting *domain.AnnouncementTargeting
starts_at *time.Time
ends_at *time.Time
@@ -5391,6 +5392,42 @@ func (m *AnnouncementMutation) ResetStatus() {
m.status = nil
}
// SetNotifyMode sets the "notify_mode" field.
func (m *AnnouncementMutation) SetNotifyMode(s string) {
m.notify_mode = &s
}
// NotifyMode returns the value of the "notify_mode" field in the mutation.
func (m *AnnouncementMutation) NotifyMode() (r string, exists bool) {
v := m.notify_mode
if v == nil {
return
}
return *v, true
}
// OldNotifyMode returns the old "notify_mode" field's value of the Announcement entity.
// If the Announcement object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AnnouncementMutation) OldNotifyMode(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldNotifyMode is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldNotifyMode requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldNotifyMode: %w", err)
}
return oldValue.NotifyMode, nil
}
// ResetNotifyMode resets all changes to the "notify_mode" field.
func (m *AnnouncementMutation) ResetNotifyMode() {
m.notify_mode = nil
}
// SetTargeting sets the "targeting" field.
func (m *AnnouncementMutation) SetTargeting(dt domain.AnnouncementTargeting) {
m.targeting = &dt
@@ -5838,7 +5875,7 @@ func (m *AnnouncementMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AnnouncementMutation) Fields() []string {
fields := make([]string, 0, 10)
fields := make([]string, 0, 11)
if m.title != nil {
fields = append(fields, announcement.FieldTitle)
}
@@ -5848,6 +5885,9 @@ func (m *AnnouncementMutation) Fields() []string {
if m.status != nil {
fields = append(fields, announcement.FieldStatus)
}
if m.notify_mode != nil {
fields = append(fields, announcement.FieldNotifyMode)
}
if m.targeting != nil {
fields = append(fields, announcement.FieldTargeting)
}
@@ -5883,6 +5923,8 @@ func (m *AnnouncementMutation) Field(name string) (ent.Value, bool) {
return m.Content()
case announcement.FieldStatus:
return m.Status()
case announcement.FieldNotifyMode:
return m.NotifyMode()
case announcement.FieldTargeting:
return m.Targeting()
case announcement.FieldStartsAt:
@@ -5912,6 +5954,8 @@ func (m *AnnouncementMutation) OldField(ctx context.Context, name string) (ent.V
return m.OldContent(ctx)
case announcement.FieldStatus:
return m.OldStatus(ctx)
case announcement.FieldNotifyMode:
return m.OldNotifyMode(ctx)
case announcement.FieldTargeting:
return m.OldTargeting(ctx)
case announcement.FieldStartsAt:
@@ -5956,6 +6000,13 @@ func (m *AnnouncementMutation) SetField(name string, value ent.Value) error {
}
m.SetStatus(v)
return nil
case announcement.FieldNotifyMode:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetNotifyMode(v)
return nil
case announcement.FieldTargeting:
v, ok := value.(domain.AnnouncementTargeting)
if !ok {
@@ -6123,6 +6174,9 @@ func (m *AnnouncementMutation) ResetField(name string) error {
case announcement.FieldStatus:
m.ResetStatus()
return nil
case announcement.FieldNotifyMode:
m.ResetNotifyMode()
return nil
case announcement.FieldTargeting:
m.ResetTargeting()
return nil
@@ -8196,6 +8250,8 @@ type GroupMutation struct {
appendsupported_model_scopes []string
sort_order *int
addsort_order *int
allow_messages_dispatch *bool
default_mapped_model *string
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -9940,6 +9996,78 @@ func (m *GroupMutation) ResetSortOrder() {
m.addsort_order = nil
}
// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
func (m *GroupMutation) SetAllowMessagesDispatch(b bool) {
m.allow_messages_dispatch = &b
}
// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation.
func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) {
v := m.allow_messages_dispatch
if v == nil {
return
}
return *v, true
}
// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err)
}
return oldValue.AllowMessagesDispatch, nil
}
// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field.
func (m *GroupMutation) ResetAllowMessagesDispatch() {
m.allow_messages_dispatch = nil
}
// SetDefaultMappedModel sets the "default_mapped_model" field.
func (m *GroupMutation) SetDefaultMappedModel(s string) {
m.default_mapped_model = &s
}
// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation.
func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) {
v := m.default_mapped_model
if v == nil {
return
}
return *v, true
}
// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err)
}
return oldValue.DefaultMappedModel, nil
}
// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field.
func (m *GroupMutation) ResetDefaultMappedModel() {
m.default_mapped_model = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -10298,7 +10426,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 31)
fields := make([]string, 0, 32)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -10389,6 +10517,12 @@ func (m *GroupMutation) Fields() []string {
if m.sort_order != nil {
fields = append(fields, group.FieldSortOrder)
}
if m.allow_messages_dispatch != nil {
fields = append(fields, group.FieldAllowMessagesDispatch)
}
if m.default_mapped_model != nil {
fields = append(fields, group.FieldDefaultMappedModel)
}
return fields
}
@@ -10457,6 +10591,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.SupportedModelScopes()
case group.FieldSortOrder:
return m.SortOrder()
case group.FieldAllowMessagesDispatch:
return m.AllowMessagesDispatch()
case group.FieldDefaultMappedModel:
return m.DefaultMappedModel()
}
return nil, false
}
@@ -10526,6 +10664,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldSupportedModelScopes(ctx)
case group.FieldSortOrder:
return m.OldSortOrder(ctx)
case group.FieldAllowMessagesDispatch:
return m.OldAllowMessagesDispatch(ctx)
case group.FieldDefaultMappedModel:
return m.OldDefaultMappedModel(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -10745,6 +10887,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetSortOrder(v)
return nil
case group.FieldAllowMessagesDispatch:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAllowMessagesDispatch(v)
return nil
case group.FieldDefaultMappedModel:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetDefaultMappedModel(v)
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -11172,6 +11328,12 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldSortOrder:
m.ResetSortOrder()
return nil
case group.FieldAllowMessagesDispatch:
m.ResetAllowMessagesDispatch()
return nil
case group.FieldDefaultMappedModel:
m.ResetDefaultMappedModel()
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}

View File

@@ -277,12 +277,18 @@ func init() {
announcement.DefaultStatus = announcementDescStatus.Default.(string)
// announcement.StatusValidator is a validator for the "status" field. It is called by the builders before save.
announcement.StatusValidator = announcementDescStatus.Validators[0].(func(string) error)
// announcementDescNotifyMode is the schema descriptor for notify_mode field.
announcementDescNotifyMode := announcementFields[3].Descriptor()
// announcement.DefaultNotifyMode holds the default value on creation for the notify_mode field.
announcement.DefaultNotifyMode = announcementDescNotifyMode.Default.(string)
// announcement.NotifyModeValidator is a validator for the "notify_mode" field. It is called by the builders before save.
announcement.NotifyModeValidator = announcementDescNotifyMode.Validators[0].(func(string) error)
// announcementDescCreatedAt is the schema descriptor for created_at field.
announcementDescCreatedAt := announcementFields[8].Descriptor()
announcementDescCreatedAt := announcementFields[9].Descriptor()
// announcement.DefaultCreatedAt holds the default value on creation for the created_at field.
announcement.DefaultCreatedAt = announcementDescCreatedAt.Default.(func() time.Time)
// announcementDescUpdatedAt is the schema descriptor for updated_at field.
announcementDescUpdatedAt := announcementFields[9].Descriptor()
announcementDescUpdatedAt := announcementFields[10].Descriptor()
// announcement.DefaultUpdatedAt holds the default value on creation for the updated_at field.
announcement.DefaultUpdatedAt = announcementDescUpdatedAt.Default.(func() time.Time)
// announcement.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
@@ -447,6 +453,16 @@ func init() {
groupDescSortOrder := groupFields[26].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
groupDescDefaultMappedModel := groupFields[28].Descriptor()
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0

View File

@@ -41,6 +41,10 @@ func (Announcement) Fields() []ent.Field {
MaxLen(20).
Default(domain.AnnouncementStatusDraft).
Comment("状态: draft, active, archived"),
field.String("notify_mode").
MaxLen(20).
Default(domain.AnnouncementNotifyModeSilent).
Comment("通知模式: silent(仅铃铛), popup(弹窗提醒)"),
field.JSON("targeting", domain.AnnouncementTargeting{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).

View File

@@ -148,6 +148,15 @@ func (Group) Fields() []ent.Field {
field.Int("sort_order").
Default(0).
Comment("分组显示排序,数值越小越靠前"),
// OpenAI Messages 调度配置 (added by migration 069)
field.Bool("allow_messages_dispatch").
Default(false).
Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"),
field.String("default_mapped_model").
MaxLen(100).
Default("").
Comment("默认映射模型 ID当账号级映射找不到时使用此值"),
}
}

View File

@@ -1,12 +1,13 @@
module github.com/Wei-Shaw/sub2api
go 1.25.7
go 1.26.1
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/aws/aws-sdk-go-v2 v1.41.2
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
@@ -38,8 +39,6 @@ require (
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.40.0
google.golang.org/grpc v1.75.1
google.golang.org/protobuf v1.36.10
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
@@ -53,7 +52,6 @@ require (
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
@@ -109,7 +107,6 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@@ -169,6 +166,7 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
@@ -178,8 +176,8 @@ require (
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect

View File

@@ -124,8 +124,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
@@ -182,7 +180,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -202,8 +199,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -286,10 +281,6 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=

View File

@@ -1402,7 +1402,7 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)

View File

@@ -13,6 +13,11 @@ const (
AnnouncementStatusArchived = "archived"
)
const (
AnnouncementNotifyModeSilent = "silent"
AnnouncementNotifyModePopup = "popup"
)
const (
AnnouncementConditionTypeSubscription = "subscription"
AnnouncementConditionTypeBalance = "balance"
@@ -195,17 +200,18 @@ func (c AnnouncementCondition) validate() error {
}
type Announcement struct {
ID int64
Title string
Content string
Status string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
CreatedBy *int64
UpdatedBy *int64
CreatedAt time.Time
UpdatedAt time.Time
ID int64
Title string
Content string
Status string
NotifyMode string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
CreatedBy *int64
UpdatedBy *int64
CreatedAt time.Time
UpdatedAt time.Time
}
func (a *Announcement) IsActiveAt(now time.Time) bool {

View File

@@ -122,7 +122,7 @@ type UpdateAccountRequest struct {
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
@@ -288,48 +288,32 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 窗口费用获取lite 模式从快照缓存读取,非 lite 模式执行 PostgreSQL 查询后写入缓存
// 始终获取窗口费用(PostgreSQL 聚合查询)
if len(windowCostAccountIDs) > 0 {
if lite {
// lite 模式:尝试从快照缓存读取
cacheKey := buildWindowCostCacheKey(windowCostAccountIDs)
if cached, ok := accountWindowCostCache.Get(cacheKey); ok {
if costs, ok := cached.Payload.(map[int64]float64); ok {
windowCosts = costs
}
}
// 缓存未命中则 windowCosts 保持 nil仅发生在服务刚启动时
} else {
// 非 lite 模式:执行 PostgreSQL 聚合查询(高开销)
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
return nil // 不返回错误,允许部分失败
})
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
_ = g.Wait()
// 查询完毕后写入快照缓存,供 lite 模式使用
cacheKey := buildWindowCostCacheKey(windowCostAccountIDs)
accountWindowCostCache.Set(cacheKey, windowCosts)
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
return nil // 不返回错误,允许部分失败
})
}
_ = g.Wait()
}
// Build response with concurrency info
@@ -676,6 +660,42 @@ func (h *AccountHandler) Test(c *gin.Context) {
// Error already sent via SSE, just log
return
}
if h.rateLimitService != nil {
if _, err := h.rateLimitService.RecoverAccountAfterSuccessfulTest(c.Request.Context(), accountID); err != nil {
_ = c.Error(err)
}
}
}
// RecoverState handles unified recovery of recoverable account runtime state.
// POST /api/v1/admin/accounts/:id/recover-state
func (h *AccountHandler) RecoverState(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
if h.rateLimitService == nil {
response.Error(c, http.StatusServiceUnavailable, "Rate limit service unavailable")
return
}
if _, err := h.rateLimitService.RecoverAccountState(c.Request.Context(), accountID, service.AccountRecoveryOptions{
InvalidateToken: true,
}); err != nil {
response.ErrorFrom(c, err)
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)

View File

@@ -1,25 +0,0 @@
package admin
import (
"strconv"
"strings"
"time"
)
var accountWindowCostCache = newSnapshotCache(30 * time.Second)
func buildWindowCostCacheKey(accountIDs []int64) string {
if len(accountIDs) == 0 {
return "accounts_window_cost_empty"
}
var b strings.Builder
b.Grow(len(accountIDs) * 6)
_, _ = b.WriteString("accounts_window_cost:")
for i, id := range accountIDs {
if i > 0 {
_ = b.WriteByte(',')
}
_, _ = b.WriteString(strconv.FormatInt(id, 10))
}
return b.String()
}

View File

@@ -27,21 +27,23 @@ func NewAnnouncementHandler(announcementService *service.AnnouncementService) *A
}
type CreateAnnouncementRequest struct {
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
NotifyMode string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
}
type UpdateAnnouncementRequest struct {
Title *string `json:"title"`
Content *string `json:"content"`
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
Targeting *service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
Title *string `json:"title"`
Content *string `json:"content"`
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
NotifyMode *string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
Targeting *service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
}
// List handles listing announcements with filters
@@ -110,11 +112,12 @@ func (h *AnnouncementHandler) Create(c *gin.Context) {
}
input := &service.CreateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
Targeting: req.Targeting,
ActorID: &subject.UserID,
Title: req.Title,
Content: req.Content,
Status: req.Status,
NotifyMode: req.NotifyMode,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil && *req.StartsAt > 0 {
@@ -157,11 +160,12 @@ func (h *AnnouncementHandler) Update(c *gin.Context) {
}
input := &service.UpdateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
Targeting: req.Targeting,
ActorID: &subject.UserID,
Title: req.Title,
Content: req.Content,
Status: req.Status,
NotifyMode: req.NotifyMode,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil {

View File

@@ -53,6 +53,9 @@ type CreateGroupRequest struct {
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -88,6 +91,9 @@ type UpdateGroupRequest struct {
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
DefaultMappedModel *string `json:"default_mapped_model"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -203,6 +209,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
DefaultMappedModel: req.DefaultMappedModel,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -254,6 +262,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
DefaultMappedModel: req.DefaultMappedModel,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {

View File

@@ -25,6 +25,7 @@ type createScheduledTestPlanRequest struct {
CronExpression string `json:"cron_expression" binding:"required"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
AutoRecover *bool `json:"auto_recover"`
}
type updateScheduledTestPlanRequest struct {
@@ -32,6 +33,7 @@ type updateScheduledTestPlanRequest struct {
CronExpression string `json:"cron_expression"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
AutoRecover *bool `json:"auto_recover"`
}
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
@@ -68,6 +70,9 @@ func (h *ScheduledTestHandler) Create(c *gin.Context) {
if req.Enabled != nil {
plan.Enabled = *req.Enabled
}
if req.AutoRecover != nil {
plan.AutoRecover = *req.AutoRecover
}
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
if err != nil {
@@ -109,6 +114,9 @@ func (h *ScheduledTestHandler) Update(c *gin.Context) {
if req.MaxResults > 0 {
existing.MaxResults = req.MaxResults
}
if req.AutoRecover != nil {
existing.AutoRecover = *req.AutoRecover
}
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
if err != nil {

View File

@@ -1348,6 +1348,63 @@ func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
response.Success(c, gin.H{"message": "S3 连接成功"})
}
// GetRectifierSettings 获取请求整流器配置
// GET /api/v1/admin/settings/rectifier
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
settings, err := h.settingService.GetRectifierSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.RectifierSettings{
Enabled: settings.Enabled,
ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled,
ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled,
})
}
// UpdateRectifierSettingsRequest 更新整流器配置请求
type UpdateRectifierSettingsRequest struct {
Enabled bool `json:"enabled"`
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
}
// UpdateRectifierSettings 更新请求整流器配置
// PUT /api/v1/admin/settings/rectifier
func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) {
var req UpdateRectifierSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
settings := &service.RectifierSettings{
Enabled: req.Enabled,
ThinkingSignatureEnabled: req.ThinkingSignatureEnabled,
ThinkingBudgetEnabled: req.ThinkingBudgetEnabled,
}
if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
// 重新获取设置返回
updatedSettings, err := h.settingService.GetRectifierSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.RectifierSettings{
Enabled: updatedSettings.Enabled,
ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled,
ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled,
})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"`

View File

@@ -7,10 +7,11 @@ import (
)
type Announcement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Status string `json:"status"`
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Status string `json:"status"`
NotifyMode string `json:"notify_mode"`
Targeting service.AnnouncementTargeting `json:"targeting"`
@@ -25,9 +26,10 @@ type Announcement struct {
}
type UserAnnouncement struct {
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
ID int64 `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
NotifyMode string `json:"notify_mode"`
StartsAt *time.Time `json:"starts_at,omitempty"`
EndsAt *time.Time `json:"ends_at,omitempty"`
@@ -43,17 +45,18 @@ func AnnouncementFromService(a *service.Announcement) *Announcement {
return nil
}
return &Announcement{
ID: a.ID,
Title: a.Title,
Content: a.Content,
Status: a.Status,
Targeting: a.Targeting,
StartsAt: a.StartsAt,
EndsAt: a.EndsAt,
CreatedBy: a.CreatedBy,
UpdatedBy: a.UpdatedBy,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
ID: a.ID,
Title: a.Title,
Content: a.Content,
Status: a.Status,
NotifyMode: a.NotifyMode,
Targeting: a.Targeting,
StartsAt: a.StartsAt,
EndsAt: a.EndsAt,
CreatedBy: a.CreatedBy,
UpdatedBy: a.UpdatedBy,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
}
}
@@ -62,13 +65,14 @@ func UserAnnouncementFromService(a *service.UserAnnouncement) *UserAnnouncement
return nil
}
return &UserAnnouncement{
ID: a.Announcement.ID,
Title: a.Announcement.Title,
Content: a.Announcement.Content,
StartsAt: a.Announcement.StartsAt,
EndsAt: a.Announcement.EndsAt,
ReadAt: a.ReadAt,
CreatedAt: a.Announcement.CreatedAt,
UpdatedAt: a.Announcement.UpdatedAt,
ID: a.Announcement.ID,
Title: a.Announcement.Title,
Content: a.Announcement.Content,
NotifyMode: a.Announcement.NotifyMode,
StartsAt: a.Announcement.StartsAt,
EndsAt: a.Announcement.EndsAt,
ReadAt: a.ReadAt,
CreatedAt: a.Announcement.CreatedAt,
UpdatedAt: a.Announcement.UpdatedAt,
}
}

View File

@@ -71,7 +71,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
if k == nil {
return nil
}
return &APIKey{
out := &APIKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
@@ -89,15 +89,28 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
RateLimit5h: k.RateLimit5h,
RateLimit1d: k.RateLimit1d,
RateLimit7d: k.RateLimit7d,
Usage5h: k.Usage5h,
Usage1d: k.Usage1d,
Usage7d: k.Usage7d,
Usage5h: k.EffectiveUsage5h(),
Usage1d: k.EffectiveUsage1d(),
Usage7d: k.EffectiveUsage7d(),
Window5hStart: k.Window5hStart,
Window1dStart: k.Window1dStart,
Window7dStart: k.Window7dStart,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
}
if k.Window5hStart != nil && !service.IsWindowExpired(k.Window5hStart, service.RateLimitWindow5h) {
t := k.Window5hStart.Add(service.RateLimitWindow5h)
out.Reset5hAt = &t
}
if k.Window1dStart != nil && !service.IsWindowExpired(k.Window1dStart, service.RateLimitWindow1d) {
t := k.Window1dStart.Add(service.RateLimitWindow1d)
out.Reset1dAt = &t
}
if k.Window7dStart != nil && !service.IsWindowExpired(k.Window7dStart, service.RateLimitWindow7d) {
t := k.Window7dStart.Add(service.RateLimitWindow7d)
out.Reset7dAt = &t
}
return out
}
func GroupFromServiceShallow(g *service.Group) *Group {
@@ -126,6 +139,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
@@ -164,6 +178,7 @@ func groupFromServiceBase(g *service.Group) Group {
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
AllowMessagesDispatch: g.AllowMessagesDispatch,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
@@ -253,11 +268,19 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if a.Type == service.AccountTypeAPIKey {
if limit := a.GetQuotaLimit(); limit > 0 {
out.QuotaLimit = &limit
}
used := a.GetQuotaUsed()
if out.QuotaLimit != nil {
used := a.GetQuotaUsed()
out.QuotaUsed = &used
}
if limit := a.GetQuotaDailyLimit(); limit > 0 {
out.QuotaDailyLimit = &limit
used := a.GetQuotaDailyUsed()
out.QuotaDailyUsed = &used
}
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
out.QuotaWeeklyLimit = &limit
used := a.GetQuotaWeeklyUsed()
out.QuotaWeeklyUsed = &used
}
}
return out
@@ -473,6 +496,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,

View File

@@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) {
t.Parallel()
require.Nil(t, requestTypeStringPtr(nil))
}
func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
t.Parallel()
serviceTier := "priority"
log := &service.UsageLog{
RequestID: "req_3",
Model: "gpt-5.4",
ServiceTier: &serviceTier,
AccountRateMultiplier: f64Ptr(1.5),
}
userDTO := UsageLogFromService(log)
adminDTO := UsageLogFromServiceAdmin(log)
require.NotNil(t, userDTO.ServiceTier)
require.Equal(t, serviceTier, *userDTO.ServiceTier)
require.NotNil(t, adminDTO.ServiceTier)
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
require.NotNil(t, adminDTO.AccountRateMultiplier)
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
}
func f64Ptr(value float64) *float64 {
return &value
}

View File

@@ -161,6 +161,13 @@ type StreamTimeoutSettings struct {
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
}
// RectifierSettings 请求整流器配置 DTO
type RectifierSettings struct {
Enabled bool `json:"enabled"`
ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"`
ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {

View File

@@ -57,6 +57,9 @@ type APIKey struct {
Window5hStart *time.Time `json:"window_5h_start"`
Window1dStart *time.Time `json:"window_1d_start"`
Window7dStart *time.Time `json:"window_7d_start"`
Reset5hAt *time.Time `json:"reset_5h_at,omitempty"`
Reset1dAt *time.Time `json:"reset_1d_at,omitempty"`
Reset7dAt *time.Time `json:"reset_7d_at,omitempty"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
@@ -96,6 +99,9 @@ type Group struct {
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -112,6 +118,9 @@ type AdminGroup struct {
// MCP XML 协议注入(仅 antigravity 平台使用)
MCPXMLInject bool `json:"mcp_xml_inject"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
DefaultMappedModel string `json:"default_mapped_model"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -187,8 +196,12 @@ type Account struct {
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
// API Key 账号配额限制
QuotaLimit *float64 `json:"quota_limit,omitempty"`
QuotaUsed *float64 `json:"quota_used,omitempty"`
QuotaLimit *float64 `json:"quota_limit,omitempty"`
QuotaUsed *float64 `json:"quota_used,omitempty"`
QuotaDailyLimit *float64 `json:"quota_daily_limit,omitempty"`
QuotaDailyUsed *float64 `json:"quota_daily_used,omitempty"`
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -309,6 +322,8 @@ type UsageLog struct {
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
ReasoningEffort *string `json:"reasoning_effort,omitempty"`

View File

@@ -30,7 +30,7 @@ const (
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
maxSameAccountRetries = 3
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。

View File

@@ -291,35 +291,31 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
require.Less(t, elapsed, 2*time.Second)
})
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
t.Run("达到最大重试次数前均返回FailoverContinue", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SameAccountRetryCount[100])
for i := 1; i <= maxSameAccountRetries; i++ {
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, i, fs.SameAccountRetryCount[100])
}
// 第二次
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
require.Empty(t, mock.calls, "达到最大重试次数前均不应调用 TempUnschedule")
})
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
t.Run("超过最大重试次数后触发TempUnschedule并切换", func(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, false)
err := newTestFailoverErr(400, true, false)
// 第一次、第二次重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, 2, fs.SameAccountRetryCount[100])
for i := 0; i < maxSameAccountRetries; i++ {
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
}
require.Equal(t, maxSameAccountRetries, fs.SameAccountRetryCount[100])
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
// 第 maxSameAccountRetries+1 次:重试耗尽,应切换账号
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
@@ -354,13 +350,14 @@ func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
err := newTestFailoverErr(400, true, false)
// 耗尽账号 100 的重试
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
// 第三次: 重试耗尽 → 切换
for i := 0; i < maxSameAccountRetries; i++ {
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
}
// 第 maxSameAccountRetries+1 次: 重试耗尽 → 切换
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
// 再次遇到账号 100计数仍为 2,条件不满足 → 直接切换
// 再次遇到账号 100计数仍为 maxSameAccountRetries,条件不满足 → 直接切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
require.Equal(t, FailoverContinue, action)
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
@@ -386,9 +383,10 @@ func TestHandleFailoverError_TempUnschedule(t *testing.T) {
fs := NewFailoverState(3, false)
err := newTestFailoverErr(502, true, false)
// 耗尽重试
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
for i := 0; i < maxSameAccountRetries; i++ {
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
}
// 再次触发时才会执行 TempUnschedule + 切换
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
require.Len(t, mock.calls, 1)
@@ -521,17 +519,16 @@ func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
mock := &mockTempUnscheduler{}
fs := NewFailoverState(3, true) // hasBoundSession=true
// 1. 账号 100 遇到可重试错误,同账号重试 2
// 1. 账号 100 遇到可重试错误,同账号重试 maxSameAccountRetries
retryErr := newTestFailoverErr(400, true, false)
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
for i := 0; i < maxSameAccountRetries; i++ {
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
}
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
// 2. 账号 100 超过重试上限 → TempUnschedule + 切换
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
require.Equal(t, FailoverContinue, action)
require.Equal(t, 1, fs.SwitchCount)
require.Len(t, mock.calls, 1)

View File

@@ -971,34 +971,46 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
if err == nil && rateLimitData != nil {
var rateLimits []gin.H
if apiKey.RateLimit5h > 0 {
used := rateLimitData.Usage5h
rateLimits = append(rateLimits, gin.H{
used := rateLimitData.EffectiveUsage5h()
entry := gin.H{
"window": "5h",
"limit": apiKey.RateLimit5h,
"used": used,
"remaining": max(0, apiKey.RateLimit5h-used),
"window_start": rateLimitData.Window5hStart,
})
}
if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) {
entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h)
}
rateLimits = append(rateLimits, entry)
}
if apiKey.RateLimit1d > 0 {
used := rateLimitData.Usage1d
rateLimits = append(rateLimits, gin.H{
used := rateLimitData.EffectiveUsage1d()
entry := gin.H{
"window": "1d",
"limit": apiKey.RateLimit1d,
"used": used,
"remaining": max(0, apiKey.RateLimit1d-used),
"window_start": rateLimitData.Window1dStart,
})
}
if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) {
entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d)
}
rateLimits = append(rateLimits, entry)
}
if apiKey.RateLimit7d > 0 {
used := rateLimitData.Usage7d
rateLimits = append(rateLimits, gin.H{
used := rateLimitData.EffectiveUsage7d()
entry := gin.H{
"window": "7d",
"limit": apiKey.RateLimit7d,
"used": used,
"remaining": max(0, apiKey.RateLimit7d-used),
"window_start": rateLimitData.Window7dStart,
})
}
if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) {
entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d)
}
rateLimits = append(rateLimits, entry)
}
if len(rateLimits) > 0 {
resp["rate_limits"] = rateLimits

View File

@@ -155,6 +155,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // sessionLimitCache
nil, // rpmCache
nil, // digestStore
nil, // settingService
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。

View File

@@ -20,6 +20,7 @@ import (
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
@@ -118,6 +119,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
setOpsRequestContext(c, "", false, body)
sessionHashBody := body
if service.IsOpenAIResponsesCompactPathForTest(c) {
if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" {
c.Set(service.OpenAICompactSessionSeedKeyForTest(), compactSeed)
}
normalizedCompactBody, normalizedCompact, compactErr := service.NormalizeOpenAICompactRequestBodyForTest(body)
if compactErr != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to normalize compact request body")
return
}
if normalizedCompact {
body = normalizedCompactBody
}
}
// 校验请求体 JSON 合法性
if !gjson.ValidBytes(body) {
@@ -193,11 +208,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
for {
@@ -245,6 +261,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Float64("load_skew", scheduleDecision.LoadSkew),
)
account := selection.Account
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID, account.Platform)
@@ -274,6 +291,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
@@ -305,6 +341,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
if result != nil {
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
@@ -424,6 +463,352 @@ func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, sta
log.Warn("codex.remote_compact.failed")
}
// Messages handles Anthropic Messages API requests routed to OpenAI platform.
// POST /v1/messages (when group platform is OpenAI)
func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
streamStarted := false
defer h.recoverAnthropicMessagesPanic(c, &streamStarted)
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.anthropicErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.messages",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 检查分组是否允许 /v1/messages 调度
if apiKey.Group != nil && !apiKey.Group.AllowMessagesDispatch {
h.anthropicErrorResponse(c, http.StatusForbidden, "permission_error",
"This group does not allow /v1/messages dispatch")
return
}
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.anthropicErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
if !gjson.ValidBytes(body) {
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.anthropicErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
return
}
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
// Anthropic 格式的请求在 metadata.user_id 中携带 session 标识,
// 而非 OpenAI 的 session_id/conversation_id headers。
// 从中派生 sessionHashsticky session和 promptCacheKeyupstream cache
if sessionHash == "" || promptCacheKey == "" {
if userID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()); userID != "" {
seed := reqModel + "-" + userID
if promptCacheKey == "" {
promptCacheKey = service.GenerateSessionUUID(seed)
}
if sessionHash == "" {
sessionHash = service.DeriveSessionHashFromSeed(seed)
}
}
}
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
for {
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
c.Set("openai_messages_fallback_model", "")
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"", // no previous_response_id
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err != nil {
reqLog.Warn("openai_messages.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
reqLog.Info("openai_messages.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_messages_fallback_model", defaultModel)
}
}
if err != nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
} else {
if lastFailoverErr != nil {
h.handleAnthropicFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
}
return
}
}
if selection == nil || selection.Account == nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
account := selection.Account
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
reqLog.Debug("openai_messages.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
_ = scheduleDecision
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
// 如果使用了降级模型调度,强制使用降级模型
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_messages.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_messages.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
reqLog.Warn("openai_messages.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai_messages.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai_messages.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}
// anthropicErrorResponse writes an error in Anthropic Messages API format.
func (h *OpenAIGatewayHandler) anthropicErrorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// anthropicStreamingAwareError handles errors that may occur during streaming,
// using Anthropic SSE error format.
func (h *OpenAIGatewayHandler) anthropicStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
flusher, ok := c.Writer.(http.Flusher)
if ok {
errPayload, _ := json.Marshal(gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errPayload) //nolint:errcheck
flusher.Flush()
}
return
}
h.anthropicErrorResponse(c, status, errType, message)
}
// handleAnthropicFailoverExhausted maps upstream failover errors to Anthropic format.
func (h *OpenAIGatewayHandler) handleAnthropicFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(failoverErr.StatusCode)
h.anthropicStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// ensureAnthropicErrorResponse writes a fallback Anthropic error if no response was written.
func (h *OpenAIGatewayHandler) ensureAnthropicErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.anthropicStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
return true
}
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
return true
@@ -840,6 +1225,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
if turnErr != nil || result == nil {
return
}
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
@@ -901,6 +1289,26 @@ func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStart
)
}
// recoverAnthropicMessagesPanic recovers from panics in the Anthropic Messages
// handler and returns an Anthropic-formatted error response.
func (h *OpenAIGatewayHandler) recoverAnthropicMessagesPanic(c *gin.Context, streamStarted *bool) {
recovered := recover()
if recovered == nil {
return
}
started := streamStarted != nil && *streamStarted
requestLogger(c, "handler.openai_gateway.messages").Error(
"openai.messages_panic_recovered",
zap.Bool("stream_started", started),
zap.Any("panic", recovered),
zap.ByteString("stack", debug.Stack()),
)
if !started {
h.anthropicErrorResponse(c, http.StatusInternalServerError, "api_error", "Internal server error")
}
}
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
missing := h.missingResponsesDependencies()
if len(missing) == 0 {
@@ -1106,6 +1514,14 @@ func setOpenAIClientTransportWS(c *gin.Context) {
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
}
func ensureOpenAIPoolModeSessionHash(sessionHash string, account *service.Account) string {
if sessionHash != "" || account == nil || !account.IsPoolMode() {
return sessionHash
}
// 为当前请求生成一次性粘性会话键,确保同账号重试不会重新负载均衡到其他账号。
return "openai-pool-retry-" + uuid.NewString()
}
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
gid := int64(0)
if groupID != nil {

View File

@@ -2207,7 +2207,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}

View File

@@ -445,6 +445,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
testutil.StubSessionLimitCache{},
nil, // rpmCache
nil, // digestStore
nil, // settingService
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}

View File

@@ -49,8 +49,8 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
var defaultUserAgentVersion = "1.19.6"
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
var defaultUserAgentVersion = "1.20.4"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"

View File

@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
}
if SessionTTL != 30*time.Minute {

View File

@@ -119,23 +119,33 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
return result.Bytes()
}
// Finish 结束处理,返回最终事件和用量
// Finish 结束处理,返回最终事件和用量
// 若整个流未收到任何可解析的上游数据messageStartSent == false
// 则不补发任何结束事件,防止客户端收到没有 message_start 的残缺流。
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
var result bytes.Buffer
if !p.messageStopSent {
_, _ = result.Write(p.emitFinish(""))
}
usage := &ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
if !p.messageStartSent {
return nil, usage
}
var result bytes.Buffer
if !p.messageStopSent {
_, _ = result.Write(p.emitFinish(""))
}
return result.Bytes(), usage
}
// MessageStartSent 报告流中是否已发出过 message_start 事件(即是否收到过有效的上游数据)
func (p *StreamingProcessor) MessageStartSent() bool {
return p.messageStartSent
}
// emitMessageStart 发送 message_start 事件
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
if p.messageStartSent {

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,417 @@
package apicompat
import (
"encoding/json"
"fmt"
"strings"
)
// AnthropicToResponses converts an Anthropic Messages request directly into
// a Responses API request. This preserves fields that would be lost in a
// Chat Completions intermediary round-trip (e.g. thinking, cache_control,
// structured system prompts).
func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
input, err := convertAnthropicToResponsesInput(req.System, req.Messages)
if err != nil {
return nil, err
}
inputJSON, err := json.Marshal(input)
if err != nil {
return nil, err
}
out := &ResponsesRequest{
Model: req.Model,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: req.Stream,
Include: []string{"reasoning.encrypted_content"},
}
storeFalse := false
out.Store = &storeFalse
if req.MaxTokens > 0 {
v := req.MaxTokens
if v < minMaxOutputTokens {
v = minMaxOutputTokens
}
out.MaxOutputTokens = &v
}
if len(req.Tools) > 0 {
out.Tools = convertAnthropicToolsToResponses(req.Tools)
}
// Determine reasoning effort: only output_config.effort controls the
// level; thinking.type is ignored. Default is xhigh when unset.
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
effort := "high" // default → maps to xhigh
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
effort = req.OutputConfig.Effort
}
out.Reasoning = &ResponsesReasoning{
Effort: mapAnthropicEffortToResponses(effort),
Summary: "auto",
}
// Convert tool_choice
if len(req.ToolChoice) > 0 {
tc, err := convertAnthropicToolChoiceToResponses(req.ToolChoice)
if err != nil {
return nil, fmt.Errorf("convert tool_choice: %w", err)
}
out.ToolChoice = tc
}
return out, nil
}
// convertAnthropicToolChoiceToResponses maps Anthropic tool_choice to Responses format.
//
// {"type":"auto"} → "auto"
// {"type":"any"} → "required"
// {"type":"none"} → "none"
// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
var tc struct {
Type string `json:"type"`
Name string `json:"name"`
}
if err := json.Unmarshal(raw, &tc); err != nil {
return nil, err
}
switch tc.Type {
case "auto":
return json.Marshal("auto")
case "any":
return json.Marshal("required")
case "none":
return json.Marshal("none")
case "tool":
return json.Marshal(map[string]any{
"type": "function",
"function": map[string]string{"name": tc.Name},
})
default:
// Pass through unknown types as-is
return raw, nil
}
}
// convertAnthropicToResponsesInput builds the Responses API input items array
// from the Anthropic system field and message list.
func convertAnthropicToResponsesInput(system json.RawMessage, msgs []AnthropicMessage) ([]ResponsesInputItem, error) {
var out []ResponsesInputItem
// System prompt → system role input item.
if len(system) > 0 {
sysText, err := parseAnthropicSystemPrompt(system)
if err != nil {
return nil, err
}
if sysText != "" {
content, _ := json.Marshal(sysText)
out = append(out, ResponsesInputItem{
Role: "system",
Content: content,
})
}
}
for _, m := range msgs {
items, err := anthropicMsgToResponsesItems(m)
if err != nil {
return nil, err
}
out = append(out, items...)
}
return out, nil
}
// parseAnthropicSystemPrompt handles the Anthropic system field which can be
// a plain string or an array of text blocks.
func parseAnthropicSystemPrompt(raw json.RawMessage) (string, error) {
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s, nil
}
var blocks []AnthropicContentBlock
if err := json.Unmarshal(raw, &blocks); err != nil {
return "", err
}
var parts []string
for _, b := range blocks {
if b.Type == "text" && b.Text != "" {
parts = append(parts, b.Text)
}
}
return strings.Join(parts, "\n\n"), nil
}
// anthropicMsgToResponsesItems converts a single Anthropic message into one
// or more Responses API input items.
func anthropicMsgToResponsesItems(m AnthropicMessage) ([]ResponsesInputItem, error) {
switch m.Role {
case "user":
return anthropicUserToResponses(m.Content)
case "assistant":
return anthropicAssistantToResponses(m.Content)
default:
return anthropicUserToResponses(m.Content)
}
}
// anthropicUserToResponses handles an Anthropic user message. Content can be a
// plain string or an array of blocks. tool_result blocks are extracted into
// function_call_output items. Image blocks are converted to input_image parts.
func anthropicUserToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
// Try plain string.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
content, _ := json.Marshal(s)
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
var blocks []AnthropicContentBlock
if err := json.Unmarshal(raw, &blocks); err != nil {
return nil, err
}
var out []ResponsesInputItem
var toolResultImageParts []ResponsesContentPart
// Extract tool_result blocks → function_call_output items.
// Images inside tool_results are extracted separately because the
// Responses API function_call_output.output only accepts strings.
for _, b := range blocks {
if b.Type != "tool_result" {
continue
}
outputText, imageParts := convertToolResultOutput(b)
out = append(out, ResponsesInputItem{
Type: "function_call_output",
CallID: toResponsesCallID(b.ToolUseID),
Output: outputText,
})
toolResultImageParts = append(toolResultImageParts, imageParts...)
}
// Remaining text + image blocks → user message with content parts.
// Also include images extracted from tool_results so the model can see them.
var parts []ResponsesContentPart
for _, b := range blocks {
switch b.Type {
case "text":
if b.Text != "" {
parts = append(parts, ResponsesContentPart{Type: "input_text", Text: b.Text})
}
case "image":
if uri := anthropicImageToDataURI(b.Source); uri != "" {
parts = append(parts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
}
}
}
parts = append(parts, toolResultImageParts...)
if len(parts) > 0 {
content, err := json.Marshal(parts)
if err != nil {
return nil, err
}
out = append(out, ResponsesInputItem{Role: "user", Content: content})
}
return out, nil
}
// anthropicAssistantToResponses handles an Anthropic assistant message.
// Text content → assistant message with output_text parts.
// tool_use blocks → function_call items.
// thinking blocks → ignored (OpenAI doesn't accept them as input).
func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, error) {
// Try plain string.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "assistant", Content: partsJSON}}, nil
}
var blocks []AnthropicContentBlock
if err := json.Unmarshal(raw, &blocks); err != nil {
return nil, err
}
var items []ResponsesInputItem
// Text content → assistant message with output_text content parts.
text := extractAnthropicTextFromBlocks(blocks)
if text != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: text}}
partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
}
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
}
// tool_use → function_call items.
for _, b := range blocks {
if b.Type != "tool_use" {
continue
}
args := "{}"
if len(b.Input) > 0 {
args = string(b.Input)
}
fcID := toResponsesCallID(b.ID)
items = append(items, ResponsesInputItem{
Type: "function_call",
CallID: fcID,
Name: b.Name,
Arguments: args,
ID: fcID,
})
}
return items, nil
}
// toResponsesCallID converts an Anthropic tool ID (toolu_xxx / call_xxx) to a
// Responses API function_call ID that starts with "fc_".
func toResponsesCallID(id string) string {
if strings.HasPrefix(id, "fc_") {
return id
}
return "fc_" + id
}
// fromResponsesCallID reverses toResponsesCallID, stripping the "fc_" prefix
// that was added during request conversion.
func fromResponsesCallID(id string) string {
if after, ok := strings.CutPrefix(id, "fc_"); ok {
// Only strip if the remainder doesn't look like it was already "fc_" prefixed.
// E.g. "fc_toolu_xxx" → "toolu_xxx", "fc_call_xxx" → "call_xxx"
if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
return after
}
}
return id
}
// anthropicImageToDataURI converts an AnthropicImageSource to a data URI string.
// Returns "" if the source is nil or has no data.
func anthropicImageToDataURI(src *AnthropicImageSource) string {
if src == nil || src.Data == "" {
return ""
}
mediaType := src.MediaType
if mediaType == "" {
mediaType = "image/png"
}
return "data:" + mediaType + ";base64," + src.Data
}
// convertToolResultOutput extracts text and image content from a tool_result
// block. Returns the text as a string for the function_call_output Output
// field, plus any image parts that must be sent in a separate user message
// (the Responses API output field only accepts strings).
func convertToolResultOutput(b AnthropicContentBlock) (string, []ResponsesContentPart) {
if len(b.Content) == 0 {
return "(empty)", nil
}
// Try plain string content.
var s string
if err := json.Unmarshal(b.Content, &s); err == nil {
if s == "" {
s = "(empty)"
}
return s, nil
}
// Array of content blocks — may contain text and/or images.
var inner []AnthropicContentBlock
if err := json.Unmarshal(b.Content, &inner); err != nil {
return "(empty)", nil
}
// Separate text (for function_call_output) from images (for user message).
var textParts []string
var imageParts []ResponsesContentPart
for _, ib := range inner {
switch ib.Type {
case "text":
if ib.Text != "" {
textParts = append(textParts, ib.Text)
}
case "image":
if uri := anthropicImageToDataURI(ib.Source); uri != "" {
imageParts = append(imageParts, ResponsesContentPart{Type: "input_image", ImageURL: uri})
}
}
}
text := strings.Join(textParts, "\n\n")
if text == "" {
text = "(empty)"
}
return text, imageParts
}
// extractAnthropicTextFromBlocks joins all text blocks, ignoring thinking/
// tool_use/tool_result blocks.
func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
var parts []string
for _, b := range blocks {
if b.Type == "text" && b.Text != "" {
parts = append(parts, b.Text)
}
}
return strings.Join(parts, "\n\n")
}
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
// OpenAI Responses API effort levels.
//
// low → low
// medium → high
// high → xhigh
func mapAnthropicEffortToResponses(effort string) string {
switch effort {
case "medium":
return "high"
case "high":
return "xhigh"
default:
return effort // "low" and any unknown values pass through unchanged
}
}
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
// Responses API tools. Server-side tools like web_search are mapped to their
// OpenAI equivalents; regular tools become function tools.
func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
var out []ResponsesTool
for _, t := range tools {
// Anthropic server tools like "web_search_20250305" → OpenAI {"type":"web_search"}
if strings.HasPrefix(t.Type, "web_search") {
out = append(out, ResponsesTool{Type: "web_search"})
continue
}
out = append(out, ResponsesTool{
Type: "function",
Name: t.Name,
Description: t.Description,
Parameters: t.InputSchema,
})
}
return out
}

View File

@@ -0,0 +1,516 @@
package apicompat
import (
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: ResponsesResponse → AnthropicResponse
// ---------------------------------------------------------------------------
// ResponsesToAnthropic converts a Responses API response directly into an
// Anthropic Messages response. Reasoning output items are mapped to thinking
// blocks; function_call items become tool_use blocks.
func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicResponse {
out := &AnthropicResponse{
ID: resp.ID,
Type: "message",
Role: "assistant",
Model: model,
}
var blocks []AnthropicContentBlock
for _, item := range resp.Output {
switch item.Type {
case "reasoning":
summaryText := ""
for _, s := range item.Summary {
if s.Type == "summary_text" && s.Text != "" {
summaryText += s.Text
}
}
if summaryText != "" {
blocks = append(blocks, AnthropicContentBlock{
Type: "thinking",
Thinking: summaryText,
})
}
case "message":
for _, part := range item.Content {
if part.Type == "output_text" && part.Text != "" {
blocks = append(blocks, AnthropicContentBlock{
Type: "text",
Text: part.Text,
})
}
}
case "function_call":
blocks = append(blocks, AnthropicContentBlock{
Type: "tool_use",
ID: fromResponsesCallID(item.CallID),
Name: item.Name,
Input: json.RawMessage(item.Arguments),
})
case "web_search_call":
toolUseID := "srvtoolu_" + item.ID
query := ""
if item.Action != nil {
query = item.Action.Query
}
inputJSON, _ := json.Marshal(map[string]string{"query": query})
blocks = append(blocks, AnthropicContentBlock{
Type: "server_tool_use",
ID: toolUseID,
Name: "web_search",
Input: inputJSON,
})
emptyResults, _ := json.Marshal([]struct{}{})
blocks = append(blocks, AnthropicContentBlock{
Type: "web_search_tool_result",
ToolUseID: toolUseID,
Content: emptyResults,
})
}
}
if len(blocks) == 0 {
blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""})
}
out.Content = blocks
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
if resp.Usage != nil {
out.Usage = AnthropicUsage{
InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens,
}
if resp.Usage.InputTokensDetails != nil {
out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
}
}
return out
}
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
switch status {
case "incomplete":
if details != nil && details.Reason == "max_output_tokens" {
return "max_tokens"
}
return "end_turn"
case "completed":
if len(blocks) > 0 && blocks[len(blocks)-1].Type == "tool_use" {
return "tool_use"
}
return "end_turn"
default:
return "end_turn"
}
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
// ResponsesEventToAnthropicState tracks state for converting a sequence of
// Responses SSE events directly into Anthropic SSE events.
type ResponsesEventToAnthropicState struct {
MessageStartSent bool
MessageStopSent bool
ContentBlockIndex int
ContentBlockOpen bool
CurrentBlockType string // "text" | "thinking" | "tool_use"
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
OutputIndexToBlockIdx map[int]int
InputTokens int
OutputTokens int
CacheReadInputTokens int
ResponseID string
Model string
Created int64
}
// NewResponsesEventToAnthropicState returns an initialised stream state.
func NewResponsesEventToAnthropicState() *ResponsesEventToAnthropicState {
return &ResponsesEventToAnthropicState{
OutputIndexToBlockIdx: make(map[int]int),
Created: time.Now().Unix(),
}
}
// ResponsesEventToAnthropicEvents converts a single Responses SSE event into
// zero or more Anthropic SSE events, updating state as it goes.
func ResponsesEventToAnthropicEvents(
evt *ResponsesStreamEvent,
state *ResponsesEventToAnthropicState,
) []AnthropicStreamEvent {
switch evt.Type {
case "response.created":
return resToAnthHandleCreated(evt, state)
case "response.output_item.added":
return resToAnthHandleOutputItemAdded(evt, state)
case "response.output_text.delta":
return resToAnthHandleTextDelta(evt, state)
case "response.output_text.done":
return resToAnthHandleBlockDone(state)
case "response.function_call_arguments.delta":
return resToAnthHandleFuncArgsDelta(evt, state)
case "response.function_call_arguments.done":
return resToAnthHandleBlockDone(state)
case "response.output_item.done":
return resToAnthHandleOutputItemDone(evt, state)
case "response.reasoning_summary_text.delta":
return resToAnthHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done":
return resToAnthHandleBlockDone(state)
case "response.completed", "response.incomplete", "response.failed":
return resToAnthHandleCompleted(evt, state)
default:
return nil
}
}
// FinalizeResponsesAnthropicStream emits synthetic termination events if the
// stream ended without a proper completion event.
func FinalizeResponsesAnthropicStream(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if !state.MessageStartSent || state.MessageStopSent {
return nil
}
var events []AnthropicStreamEvent
events = append(events, closeCurrentBlock(state)...)
events = append(events,
AnthropicStreamEvent{
Type: "message_delta",
Delta: &AnthropicDelta{
StopReason: "end_turn",
},
Usage: &AnthropicUsage{
InputTokens: state.InputTokens,
OutputTokens: state.OutputTokens,
CacheReadInputTokens: state.CacheReadInputTokens,
},
},
AnthropicStreamEvent{Type: "message_stop"},
)
state.MessageStopSent = true
return events
}
// ResponsesAnthropicEventToSSE formats an AnthropicStreamEvent as an SSE line pair.
func ResponsesAnthropicEventToSSE(evt AnthropicStreamEvent) (string, error) {
data, err := json.Marshal(evt)
if err != nil {
return "", err
}
return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
}
// --- internal handlers ---
func resToAnthHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Response != nil {
state.ResponseID = evt.Response.ID
// Only use upstream model if no override was set (e.g. originalModel)
if state.Model == "" {
state.Model = evt.Response.Model
}
}
if state.MessageStartSent {
return nil
}
state.MessageStartSent = true
return []AnthropicStreamEvent{{
Type: "message_start",
Message: &AnthropicResponse{
ID: state.ResponseID,
Type: "message",
Role: "assistant",
Content: []AnthropicContentBlock{},
Model: state.Model,
Usage: AnthropicUsage{
InputTokens: 0,
OutputTokens: 0,
},
},
}}
}
func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Item == nil {
return nil
}
switch evt.Item.Type {
case "function_call":
var events []AnthropicStreamEvent
events = append(events, closeCurrentBlock(state)...)
idx := state.ContentBlockIndex
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
state.ContentBlockOpen = true
state.CurrentBlockType = "tool_use"
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
Index: &idx,
ContentBlock: &AnthropicContentBlock{
Type: "tool_use",
ID: fromResponsesCallID(evt.Item.CallID),
Name: evt.Item.Name,
Input: json.RawMessage("{}"),
},
})
return events
case "reasoning":
var events []AnthropicStreamEvent
events = append(events, closeCurrentBlock(state)...)
idx := state.ContentBlockIndex
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
state.ContentBlockOpen = true
state.CurrentBlockType = "thinking"
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
Index: &idx,
ContentBlock: &AnthropicContentBlock{
Type: "thinking",
Thinking: "",
},
})
return events
case "message":
return nil
}
return nil
}
func resToAnthHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Delta == "" {
return nil
}
var events []AnthropicStreamEvent
if !state.ContentBlockOpen || state.CurrentBlockType != "text" {
events = append(events, closeCurrentBlock(state)...)
idx := state.ContentBlockIndex
state.ContentBlockOpen = true
state.CurrentBlockType = "text"
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
Index: &idx,
ContentBlock: &AnthropicContentBlock{
Type: "text",
Text: "",
},
})
}
idx := state.ContentBlockIndex
events = append(events, AnthropicStreamEvent{
Type: "content_block_delta",
Index: &idx,
Delta: &AnthropicDelta{
Type: "text_delta",
Text: evt.Delta,
},
})
return events
}
func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Delta == "" {
return nil
}
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
if !ok {
return nil
}
return []AnthropicStreamEvent{{
Type: "content_block_delta",
Index: &blockIdx,
Delta: &AnthropicDelta{
Type: "input_json_delta",
PartialJSON: evt.Delta,
},
}}
}
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Delta == "" {
return nil
}
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
if !ok {
return nil
}
return []AnthropicStreamEvent{{
Type: "content_block_delta",
Index: &blockIdx,
Delta: &AnthropicDelta{
Type: "thinking_delta",
Thinking: evt.Delta,
},
}}
}
func resToAnthHandleBlockDone(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if !state.ContentBlockOpen {
return nil
}
return closeCurrentBlock(state)
}
func resToAnthHandleOutputItemDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Item == nil {
return nil
}
// Handle web_search_call → synthesize server_tool_use + web_search_tool_result blocks.
if evt.Item.Type == "web_search_call" && evt.Item.Status == "completed" {
return resToAnthHandleWebSearchDone(evt, state)
}
if state.ContentBlockOpen {
return closeCurrentBlock(state)
}
return nil
}
// resToAnthHandleWebSearchDone converts an OpenAI web_search_call output item
// into Anthropic server_tool_use + web_search_tool_result content block pairs.
// This allows Claude Code to count the searches performed.
func resToAnthHandleWebSearchDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
var events []AnthropicStreamEvent
events = append(events, closeCurrentBlock(state)...)
toolUseID := "srvtoolu_" + evt.Item.ID
query := ""
if evt.Item.Action != nil {
query = evt.Item.Action.Query
}
inputJSON, _ := json.Marshal(map[string]string{"query": query})
// Emit server_tool_use block (start + stop).
idx1 := state.ContentBlockIndex
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
Index: &idx1,
ContentBlock: &AnthropicContentBlock{
Type: "server_tool_use",
ID: toolUseID,
Name: "web_search",
Input: inputJSON,
},
})
events = append(events, AnthropicStreamEvent{
Type: "content_block_stop",
Index: &idx1,
})
state.ContentBlockIndex++
// Emit web_search_tool_result block (start + stop).
// Content is empty because OpenAI does not expose individual search results;
// the model consumes them internally and produces text output.
emptyResults, _ := json.Marshal([]struct{}{})
idx2 := state.ContentBlockIndex
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
Index: &idx2,
ContentBlock: &AnthropicContentBlock{
Type: "web_search_tool_result",
ToolUseID: toolUseID,
Content: emptyResults,
},
})
events = append(events, AnthropicStreamEvent{
Type: "content_block_stop",
Index: &idx2,
})
state.ContentBlockIndex++
return events
}
func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if state.MessageStopSent {
return nil
}
var events []AnthropicStreamEvent
events = append(events, closeCurrentBlock(state)...)
stopReason := "end_turn"
if evt.Response != nil {
if evt.Response.Usage != nil {
state.InputTokens = evt.Response.Usage.InputTokens
state.OutputTokens = evt.Response.Usage.OutputTokens
if evt.Response.Usage.InputTokensDetails != nil {
state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
}
}
switch evt.Response.Status {
case "incomplete":
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
stopReason = "max_tokens"
}
case "completed":
if state.ContentBlockIndex > 0 && state.CurrentBlockType == "tool_use" {
stopReason = "tool_use"
}
}
}
events = append(events,
AnthropicStreamEvent{
Type: "message_delta",
Delta: &AnthropicDelta{
StopReason: stopReason,
},
Usage: &AnthropicUsage{
InputTokens: state.InputTokens,
OutputTokens: state.OutputTokens,
CacheReadInputTokens: state.CacheReadInputTokens,
},
},
AnthropicStreamEvent{Type: "message_stop"},
)
state.MessageStopSent = true
return events
}
func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if !state.ContentBlockOpen {
return nil
}
idx := state.ContentBlockIndex
state.ContentBlockOpen = false
state.ContentBlockIndex++
return []AnthropicStreamEvent{{
Type: "content_block_stop",
Index: &idx,
}}
}

View File

@@ -0,0 +1,338 @@
// Package apicompat provides type definitions and conversion utilities for
// translating between Anthropic Messages and OpenAI Responses API formats.
// It enables multi-protocol support so that clients using different API
// formats can be served through a unified gateway.
package apicompat
import "encoding/json"
// ---------------------------------------------------------------------------
// Anthropic Messages API types
// ---------------------------------------------------------------------------
// AnthropicRequest is the request body for POST /v1/messages.
type AnthropicRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
Messages []AnthropicMessage `json:"messages"`
Tools []AnthropicTool `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
StopSeqs []string `json:"stop_sequences,omitempty"`
Thinking *AnthropicThinking `json:"thinking,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
}
// AnthropicOutputConfig controls output generation parameters.
type AnthropicOutputConfig struct {
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
}
// AnthropicThinking configures extended thinking in the Anthropic API.
type AnthropicThinking struct {
Type string `json:"type"` // "enabled" | "adaptive" | "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"` // max thinking tokens
}
// AnthropicMessage is a single message in the Anthropic conversation.
type AnthropicMessage struct {
Role string `json:"role"` // "user" | "assistant"
Content json.RawMessage `json:"content"`
}
// AnthropicContentBlock is one block inside a message's content array.
type AnthropicContentBlock struct {
Type string `json:"type"`
// type=text
Text string `json:"text,omitempty"`
// type=thinking
Thinking string `json:"thinking,omitempty"`
// type=image
Source *AnthropicImageSource `json:"source,omitempty"`
// type=tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
// type=tool_result
ToolUseID string `json:"tool_use_id,omitempty"`
Content json.RawMessage `json:"content,omitempty"` // string or []AnthropicContentBlock
IsError bool `json:"is_error,omitempty"`
}
// AnthropicImageSource describes the source data for an image content block.
type AnthropicImageSource struct {
Type string `json:"type"` // "base64"
MediaType string `json:"media_type"`
Data string `json:"data"`
}
// AnthropicTool describes a tool available to the model.
type AnthropicTool struct {
Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
}
// AnthropicResponse is the non-streaming response from POST /v1/messages.
type AnthropicResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Content []AnthropicContentBlock `json:"content"`
Model string `json:"model"`
StopReason string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence,omitempty"`
Usage AnthropicUsage `json:"usage"`
}
// AnthropicUsage holds token counts in Anthropic format.
type AnthropicUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
}
// ---------------------------------------------------------------------------
// Anthropic SSE event types
// ---------------------------------------------------------------------------
// AnthropicStreamEvent is a single SSE event in the Anthropic streaming protocol.
type AnthropicStreamEvent struct {
Type string `json:"type"`
// message_start
Message *AnthropicResponse `json:"message,omitempty"`
// content_block_start
Index *int `json:"index,omitempty"`
ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"`
// content_block_delta
Delta *AnthropicDelta `json:"delta,omitempty"`
// message_delta
Usage *AnthropicUsage `json:"usage,omitempty"`
}
// AnthropicDelta carries incremental content in streaming events.
type AnthropicDelta struct {
Type string `json:"type,omitempty"` // "text_delta" | "input_json_delta" | "thinking_delta" | "signature_delta"
// text_delta
Text string `json:"text,omitempty"`
// input_json_delta
PartialJSON string `json:"partial_json,omitempty"`
// thinking_delta
Thinking string `json:"thinking,omitempty"`
// signature_delta
Signature string `json:"signature,omitempty"`
// message_delta fields
StopReason string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"`
}
// ---------------------------------------------------------------------------
// OpenAI Responses API types
// ---------------------------------------------------------------------------
// ResponsesRequest is the request body for POST /v1/responses.
type ResponsesRequest struct {
Model string `json:"model"`
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []ResponsesTool `json:"tools,omitempty"`
Include []string `json:"include,omitempty"`
Store *bool `json:"store,omitempty"`
Reasoning *ResponsesReasoning `json:"reasoning,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ResponsesReasoning configures reasoning effort in the Responses API.
type ResponsesReasoning struct {
Effort string `json:"effort"` // "low" | "medium" | "high"
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
}
// ResponsesInputItem is one item in the Responses API input array.
// The Type field determines which other fields are populated.
type ResponsesInputItem struct {
// Common
Type string `json:"type,omitempty"` // "" for role-based messages
// Role-based messages (system/user/assistant)
Role string `json:"role,omitempty"`
Content json.RawMessage `json:"content,omitempty"` // string or []ResponsesContentPart
// type=function_call
CallID string `json:"call_id,omitempty"`
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
ID string `json:"id,omitempty"`
// type=function_call_output
Output string `json:"output,omitempty"`
}
// ResponsesContentPart is a typed content part in a Responses message.
type ResponsesContentPart struct {
Type string `json:"type"` // "input_text" | "output_text" | "input_image"
Text string `json:"text,omitempty"`
ImageURL string `json:"image_url,omitempty"` // data URI for input_image
}
// ResponsesTool describes a tool in the Responses API.
type ResponsesTool struct {
Type string `json:"type"` // "function" | "web_search" | "local_shell" etc.
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
// ResponsesResponse is the non-streaming response from POST /v1/responses.
type ResponsesResponse struct {
ID string `json:"id"`
Object string `json:"object"` // "response"
Model string `json:"model"`
Status string `json:"status"` // "completed" | "incomplete" | "failed"
Output []ResponsesOutput `json:"output"`
Usage *ResponsesUsage `json:"usage,omitempty"`
// incomplete_details is present when status="incomplete"
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details,omitempty"`
// Error is present when status="failed"
Error *ResponsesError `json:"error,omitempty"`
}
// ResponsesError describes an error in a failed response.
type ResponsesError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// ResponsesIncompleteDetails explains why a response is incomplete.
type ResponsesIncompleteDetails struct {
Reason string `json:"reason"` // "max_output_tokens" | "content_filter"
}
// ResponsesOutput is one output item in a Responses API response.
type ResponsesOutput struct {
Type string `json:"type"` // "message" | "reasoning" | "function_call" | "web_search_call"
// type=message
ID string `json:"id,omitempty"`
Role string `json:"role,omitempty"`
Content []ResponsesContentPart `json:"content,omitempty"`
Status string `json:"status,omitempty"`
// type=reasoning
EncryptedContent string `json:"encrypted_content,omitempty"`
Summary []ResponsesSummary `json:"summary,omitempty"`
// type=function_call
CallID string `json:"call_id,omitempty"`
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
// type=web_search_call
Action *WebSearchAction `json:"action,omitempty"`
}
// WebSearchAction describes the search action in a web_search_call output item.
type WebSearchAction struct {
Type string `json:"type,omitempty"` // "search"
Query string `json:"query,omitempty"` // primary search query
}
// ResponsesSummary is a summary text block inside a reasoning output.
type ResponsesSummary struct {
Type string `json:"type"` // "summary_text"
Text string `json:"text"`
}
// ResponsesUsage holds token counts in Responses API format.
type ResponsesUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
// Optional detailed breakdown
InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"`
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
}
// ResponsesInputTokensDetails breaks down input token usage.
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`
}
// ResponsesOutputTokensDetails breaks down output token usage.
type ResponsesOutputTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
}
// ---------------------------------------------------------------------------
// Responses SSE event types
// ---------------------------------------------------------------------------
// ResponsesStreamEvent is a single SSE event in the Responses streaming protocol.
// The Type field corresponds to the "type" in the JSON payload.
type ResponsesStreamEvent struct {
Type string `json:"type"`
// response.created / response.completed / response.failed / response.incomplete
Response *ResponsesResponse `json:"response,omitempty"`
// response.output_item.added / response.output_item.done
Item *ResponsesOutput `json:"item,omitempty"`
// response.output_text.delta / response.output_text.done
OutputIndex int `json:"output_index,omitempty"`
ContentIndex int `json:"content_index,omitempty"`
Delta string `json:"delta,omitempty"`
Text string `json:"text,omitempty"`
ItemID string `json:"item_id,omitempty"`
// response.function_call_arguments.delta / done
CallID string `json:"call_id,omitempty"`
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
// response.reasoning_summary_text.delta / done
// Reuses Text/Delta fields above, SummaryIndex identifies which summary part
SummaryIndex int `json:"summary_index,omitempty"`
// error event fields
Code string `json:"code,omitempty"`
Param string `json:"param,omitempty"`
// Sequence number for ordering events
SequenceNumber int `json:"sequence_number,omitempty"`
}
// ---------------------------------------------------------------------------
// Shared constants
// ---------------------------------------------------------------------------
// minMaxOutputTokens is the floor for max_output_tokens in a Responses request.
// Very small values may cause upstream API errors, so we enforce a minimum.
const minMaxOutputTokens = 128

View File

@@ -16,7 +16,7 @@ const (
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
// 这些 token 是客户端特有的,不应透传给上游 API。
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
var DroppedBetas = []string{BetaFastMode}
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming

View File

@@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool {
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
}
// IsCodexOfficialClientByHeaders checks whether the request headers indicate an
// official Codex client family request.
func IsCodexOfficialClientByHeaders(userAgent, originator string) bool {
return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator)
}
func normalizeCodexClientHeader(value string) string {
return strings.ToLower(strings.TrimSpace(value))
}

View File

@@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) {
})
}
}
func TestIsCodexOfficialClientByHeaders(t *testing.T) {
tests := []struct {
name string
ua string
originator string
want bool
}{
{name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true},
{name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true},
{name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true},
{name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator)
if got != tt.want {
t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
}
})
}
}

View File

@@ -659,13 +659,10 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
if err != nil {
return err
}
// 清除临时不可调度状态,重置 401 升级链
_, _ = r.sql.ExecContext(ctx, `
UPDATE accounts
SET temp_unschedulable_until = NULL,
temp_unschedulable_reason = NULL
WHERE id = $1 AND deleted_at IS NULL
`, id)
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear error failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
@@ -925,6 +922,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
@@ -1040,6 +1038,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}
@@ -1676,13 +1675,47 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
return r.accountsToService(ctx, accounts)
}
// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
// 日/周额度在周期过期时自动重置为 0 再递增。
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
rows, err := r.sql.QueryContext(ctx,
`UPDATE accounts SET extra = jsonb_set(
COALESCE(extra, '{}'::jsonb),
'{quota_used}',
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
-- 总额度:始终递增
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
-- 日额度:仅在 quota_daily_limit > 0 时处理
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
@@ -1704,7 +1737,7 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
return err
}
// 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除
// 任一维度配额刚超限时触发调度快照刷新
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
@@ -1713,14 +1746,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
return nil
}
// ResetQuotaUsed 重置账号的 extra.quota_used 为 0
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx,
`UPDATE accounts SET extra = jsonb_set(
COALESCE(extra, '{}'::jsonb),
'{quota_used}',
'0'::jsonb
), updated_at = NOW()
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
if err != nil {

View File

@@ -558,6 +558,26 @@ func (s *AccountRepoSuite) TestSetError() {
s.Require().Equal("something went wrong", got.ErrorMessage)
}
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-clear-err",
Status: service.StatusError,
ErrorMessage: "temporary error",
})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(service.StatusActive, got.Status)
s.Require().Empty(got.ErrorMessage)
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
}
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {

View File

@@ -24,6 +24,7 @@ func (r *announcementRepository) Create(ctx context.Context, a *service.Announce
SetTitle(a.Title).
SetContent(a.Content).
SetStatus(a.Status).
SetNotifyMode(a.NotifyMode).
SetTargeting(a.Targeting)
if a.StartsAt != nil {
@@ -64,6 +65,7 @@ func (r *announcementRepository) Update(ctx context.Context, a *service.Announce
SetTitle(a.Title).
SetContent(a.Content).
SetStatus(a.Status).
SetNotifyMode(a.NotifyMode).
SetTargeting(a.Targeting)
if a.StartsAt != nil {
@@ -169,17 +171,18 @@ func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
return nil
}
return &service.Announcement{
ID: m.ID,
Title: m.Title,
Content: m.Content,
Status: m.Status,
Targeting: m.Targeting,
StartsAt: m.StartsAt,
EndsAt: m.EndsAt,
CreatedBy: m.CreatedBy,
UpdatedBy: m.UpdatedBy,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
ID: m.ID,
Title: m.Title,
Content: m.Content,
Status: m.Status,
NotifyMode: m.NotifyMode,
Targeting: m.Targeting,
StartsAt: m.StartsAt,
EndsAt: m.EndsAt,
CreatedBy: m.CreatedBy,
UpdatedBy: m.UpdatedBy,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}

View File

@@ -165,6 +165,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldModelRouting,
group.FieldMcpXMLInject,
group.FieldSupportedModelScopes,
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
)
}).
Only(ctx)
@@ -470,12 +472,12 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = usage_5h + $1,
usage_1d = usage_1d + $1,
usage_7d = usage_7d + $1,
window_5h_start = COALESCE(window_5h_start, NOW()),
window_1d_start = COALESCE(window_1d_start, NOW()),
window_7d_start = COALESCE(window_7d_start, NOW()),
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL`,
cost, id)
@@ -489,9 +491,9 @@ func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64)
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
@@ -619,6 +621,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
MCPXMLInject: g.McpXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
SortOrder: g.SortOrder,
AllowMessagesDispatch: g.AllowMessagesDispatch,
DefaultMappedModel: g.DefaultMappedModel,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -89,6 +89,10 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
_ = client.Close()
return nil, nil, err
}
if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil {
_ = client.Close()
return nil, nil, err
}
}
return client, drv.DB(), nil

View File

@@ -59,7 +59,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetDefaultMappedModel(groupIn.DefaultMappedModel)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -125,7 +127,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetDefaultMappedModel(groupIn.DefaultMappedModel)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {

View File

@@ -20,16 +20,16 @@ func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanReposit
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, auto_recover, next_run_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
return scanPlan(row)
}
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans WHERE id = $1
`, id)
return scanPlan(row)
@@ -37,7 +37,7 @@ func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*s
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans WHERE account_id = $1
ORDER BY created_at DESC
`, accountID)
@@ -50,7 +50,7 @@ func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accou
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
SELECT id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
FROM scheduled_test_plans
WHERE enabled = true AND next_run_at <= $1
ORDER BY next_run_at ASC
@@ -65,10 +65,10 @@ func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
row := r.db.QueryRowContext(ctx, `
UPDATE scheduled_test_plans
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, auto_recover = $6, next_run_at = $7, updated_at = NOW()
WHERE id = $1
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, auto_recover, last_run_at, next_run_at, created_at, updated_at
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.AutoRecover, plan.NextRunAt)
return scanPlan(row)
}
@@ -162,7 +162,7 @@ type scannable interface {
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
p := &service.ScheduledTestPlan{}
if err := row.Scan(
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults, &p.AutoRecover,
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, err

View File

@@ -0,0 +1,55 @@
package repository
import (
"context"
"fmt"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/setting"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
simpleModeAdminConcurrencyUpgradeKey = "simple_mode_admin_concurrency_upgraded_30"
simpleModeLegacyAdminConcurrency = 5
simpleModeTargetAdminConcurrency = 30
)
func ensureSimpleModeAdminConcurrency(ctx context.Context, client *dbent.Client) error {
if client == nil {
return fmt.Errorf("nil ent client")
}
upgraded, err := client.Setting.Query().Where(setting.KeyEQ(simpleModeAdminConcurrencyUpgradeKey)).Exist(ctx)
if err != nil {
return fmt.Errorf("check admin concurrency upgrade marker: %w", err)
}
if upgraded {
return nil
}
if _, err := client.User.Update().
Where(
dbuser.RoleEQ(service.RoleAdmin),
dbuser.ConcurrencyEQ(simpleModeLegacyAdminConcurrency),
).
SetConcurrency(simpleModeTargetAdminConcurrency).
Save(ctx); err != nil {
return fmt.Errorf("upgrade simple mode admin concurrency: %w", err)
}
now := time.Now()
if err := client.Setting.Create().
SetKey(simpleModeAdminConcurrencyUpgradeKey).
SetValue(now.Format(time.RFC3339)).
SetUpdatedAt(now).
OnConflictColumns(setting.FieldKey).
UpdateNewValues().
Exec(ctx); err != nil {
return fmt.Errorf("persist admin concurrency upgrade marker: %w", err)
}
return nil
}

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
@@ -135,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
@@ -144,7 +145,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -158,6 +159,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
var requestIDArg any
@@ -198,6 +200,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.ImageCount,
imageSize,
mediaType,
serviceTier,
reasoningEffort,
log.CacheTTLOverridden,
createdAt,
@@ -2505,6 +2508,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
imageCount int
imageSize sql.NullString
mediaType sql.NullString
serviceTier sql.NullString
reasoningEffort sql.NullString
cacheTTLOverridden bool
createdAt time.Time
@@ -2544,6 +2548,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&imageCount,
&imageSize,
&mediaType,
&serviceTier,
&reasoningEffort,
&cacheTTLOverridden,
&createdAt,
@@ -2614,6 +2619,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if mediaType.Valid {
log.MediaType = &mediaType.String
}
if serviceTier.Valid {
log.ServiceTier = &serviceTier.String
}
if reasoningEffort.Valid {
log.ReasoningEffort = &reasoningEffort.String
}

View File

@@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.ImageCount,
sqlmock.AnyArg(), // image_size
sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
log.CacheTTLOverridden,
createdAt,
@@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
require.NoError(t, err)
require.True(t, inserted)
require.Equal(t, int64(99), log.ID)
require.Nil(t, log.ServiceTier)
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
require.True(t, log.Stream)
require.True(t, log.OpenAIWSMode)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
serviceTier := "priority"
log := &service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-service-tier",
Model: "gpt-5.4",
ServiceTier: &serviceTier,
CreatedAt: createdAt,
}
mock.ExpectQuery("INSERT INTO usage_logs").
WithArgs(
log.UserID,
log.APIKeyID,
log.AccountID,
log.RequestID,
log.Model,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
log.CacheReadCost,
log.TotalCost,
log.ActualCost,
log.RateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
int16(service.RequestTypeSync),
false,
false,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.ImageCount,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
serviceTier,
sqlmock.AnyArg(),
log.CacheTTLOverridden,
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
inserted, err := repo.Create(context.Background(), log)
require.NoError(t, err)
require.True(t, inserted)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
@@ -280,11 +345,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
0,
sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
false,
now,
}})
require.NoError(t, err)
require.NotNil(t, log.ServiceTier)
require.Equal(t, "priority", *log.ServiceTier)
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
require.True(t, log.Stream)
require.True(t, log.OpenAIWSMode)
@@ -316,13 +384,53 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
0,
sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
false,
now,
}})
require.NoError(t, err)
require.NotNil(t, log.ServiceTier)
require.Equal(t, "flex", *log.ServiceTier)
require.Equal(t, service.RequestTypeStream, log.RequestType)
require.True(t, log.Stream)
require.False(t, log.OpenAIWSMode)
})
t.Run("service_tier_is_scanned", func(t *testing.T) {
now := time.Now().UTC()
log, err := scanUsageLog(usageLogScannerStub{values: []any{
int64(3),
int64(12),
int64(22),
int64(32),
sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4",
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
int16(service.BillingTypeBalance),
int16(service.RequestTypeSync),
false,
false,
sql.NullInt64{},
sql.NullInt64{},
sql.NullString{},
sql.NullString{},
0,
sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
false,
now,
}})
require.NoError(t, err)
require.NotNil(t, log.ServiceTier)
require.Equal(t, "priority", *log.ServiceTier)
})
}

View File

@@ -210,8 +210,10 @@ func TestAPIContracts(t *testing.T) {
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"allow_messages_dispatch": false,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}

View File

@@ -244,6 +244,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
@@ -392,6 +393,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
// 请求整流器配置
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)

View File

@@ -43,12 +43,33 @@ func RegisterGatewayRoutes(
gateway.Use(gin.HandlerFunc(apiKeyAuth))
gateway.Use(requireGroupAnthropic)
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
// /v1/messages: auto-route based on group platform
gateway.POST("/messages", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
h.OpenAIGateway.Messages(c)
return
}
h.Gateway.Messages(c)
})
// /v1/messages/count_tokens: OpenAI groups get 404
gateway.POST("/messages/count_tokens", func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
c.JSON(http.StatusNotFound, gin.H{
"type": "error",
"error": gin.H{
"type": "not_found_error",
"message": "Token counting is not supported for this platform",
},
})
return
}
h.Gateway.CountTokens(c)
})
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// 明确阻止旧协议入口OpenAI 仅支持 Responses API避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) {
@@ -77,6 +98,7 @@ func RegisterGatewayRoutes(
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// Antigravity 模型列表
@@ -132,3 +154,12 @@ func RegisterGatewayRoutes(
// Sora 媒体代理(签名 URL无需 API Key
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
}
// getGroupPlatform extracts the group platform from the API Key stored in context.
func getGroupPlatform(c *gin.Context) string {
apiKey, ok := middleware.GetAPIKeyFromContext(c)
if !ok || apiKey.Group == nil {
return ""
}
return apiKey.Group.Platform
}

View File

@@ -0,0 +1,51 @@
package routes
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func newGatewayRoutesTestRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
RegisterGatewayRoutes(
router,
&handler.Handlers{
Gateway: &handler.GatewayHandler{},
OpenAIGateway: &handler.OpenAIGatewayHandler{},
SoraGateway: &handler.SoraGatewayHandler{},
},
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
c.Next()
}),
nil,
nil,
nil,
nil,
&config.Config{},
)
return router
}
func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
router := newGatewayRoutesTestRouter()
for _, path := range []string{"/v1/responses/compact", "/responses/compact"} {
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
}
}

View File

@@ -647,6 +647,75 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
return false
}
// IsPoolMode 检查 API Key 账号是否启用池模式。
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
func (a *Account) IsPoolMode() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["pool_mode"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
const (
defaultPoolModeRetryCount = 3
maxPoolModeRetryCount = 10
)
// GetPoolModeRetryCount 返回池模式同账号重试次数。
// 未配置或配置非法时回退为默认值 3小于 0 按 0 处理;过大则截断到 10。
func (a *Account) GetPoolModeRetryCount() int {
if a == nil || !a.IsPoolMode() || a.Credentials == nil {
return defaultPoolModeRetryCount
}
raw, ok := a.Credentials["pool_mode_retry_count"]
if !ok || raw == nil {
return defaultPoolModeRetryCount
}
count := parsePoolModeRetryCount(raw)
if count < 0 {
return 0
}
if count > maxPoolModeRetryCount {
return maxPoolModeRetryCount
}
return count
}
func parsePoolModeRetryCount(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return defaultPoolModeRetryCount
}
// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码
func isPoolModeRetryableStatus(statusCode int) bool {
switch statusCode {
case 401, 403, 429:
return true
default:
return false
}
}
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
@@ -1134,33 +1203,97 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
// 返回 0 表示未启用
func (a *Account) GetQuotaLimit() float64 {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["quota_limit"]; ok {
return parseExtraFloat64(v)
}
return 0
return a.getExtraFloat64("quota_limit")
}
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
func (a *Account) GetQuotaUsed() float64 {
return a.getExtraFloat64("quota_used")
}
// GetQuotaDailyLimit 获取日额度限制美元0 表示未启用
func (a *Account) GetQuotaDailyLimit() float64 {
return a.getExtraFloat64("quota_daily_limit")
}
// GetQuotaDailyUsed 获取当日已用额度(美元)
func (a *Account) GetQuotaDailyUsed() float64 {
return a.getExtraFloat64("quota_daily_used")
}
// GetQuotaWeeklyLimit 获取周额度限制美元0 表示未启用
func (a *Account) GetQuotaWeeklyLimit() float64 {
return a.getExtraFloat64("quota_weekly_limit")
}
// GetQuotaWeeklyUsed 获取本周已用额度(美元)
func (a *Account) GetQuotaWeeklyUsed() float64 {
return a.getExtraFloat64("quota_weekly_used")
}
// getExtraFloat64 从 Extra 中读取指定 key 的 float64 值
func (a *Account) getExtraFloat64(key string) float64 {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["quota_used"]; ok {
if v, ok := a.Extra[key]; ok {
return parseExtraFloat64(v)
}
return 0
}
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
func (a *Account) IsQuotaExceeded() bool {
limit := a.GetQuotaLimit()
if limit <= 0 {
return false
// getExtraTime 从 Extra 中读取 RFC3339 时间戳
func (a *Account) getExtraTime(key string) time.Time {
if a.Extra == nil {
return time.Time{}
}
return a.GetQuotaUsed() >= limit
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
return t
}
if t, err := time.Parse(time.RFC3339, s); err == nil {
return t
}
}
}
return time.Time{}
}
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
func (a *Account) HasAnyQuotaLimit() bool {
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
}
// isPeriodExpired 检查指定周期(自 periodStart 起经过 dur是否已过期
func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
if periodStart.IsZero() {
return true // 从未使用过,视为过期(下次 increment 会初始化)
}
return time.Since(periodStart) >= dur
}
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true
func (a *Account) IsQuotaExceeded() bool {
// 总额度
if limit := a.GetQuotaLimit(); limit > 0 && a.GetQuotaUsed() >= limit {
return true
}
// 日额度(周期过期视为未超限,下次 increment 会重置)
if limit := a.GetQuotaDailyLimit(); limit > 0 {
start := a.getExtraTime("quota_daily_start")
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
return true
}
}
// 周额度
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
start := a.getExtraTime("quota_weekly_start")
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
return true
}
}
return false
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)

View File

@@ -0,0 +1,117 @@
//go:build unit
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetPoolModeRetryCount(t *testing.T) {
tests := []struct {
name string
account *Account
expected int
}{
{
name: "default_when_not_pool_mode",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{},
},
expected: defaultPoolModeRetryCount,
},
{
name: "default_when_missing_retry_count",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
},
expected: defaultPoolModeRetryCount,
},
{
name: "supports_float64_from_json_credentials",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": float64(5),
},
},
expected: 5,
},
{
name: "supports_json_number",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": json.Number("4"),
},
},
expected: 4,
},
{
name: "supports_string_value",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": "2",
},
},
expected: 2,
},
{
name: "negative_value_is_clamped_to_zero",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": -1,
},
},
expected: 0,
},
{
name: "oversized_value_is_clamped_to_max",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": 99,
},
},
expected: maxPoolModeRetryCount,
},
{
name: "invalid_value_falls_back_to_default",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": "oops",
},
},
expected: defaultPoolModeRetryCount,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
})
}
}

View File

@@ -68,9 +68,9 @@ type AccountRepository interface {
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周)
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
// ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0
ResetQuotaUsed(ctx context.Context, id int64) error
}

View File

@@ -406,8 +406,27 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
}
defer func() { _ = resp.Body.Close() }()
if isOAuth && s.accountRepo != nil {
if updates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
account.RateLimitResetAt = resetAt
}
}
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if isOAuth && s.accountRepo != nil {
if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
account.RateLimitResetAt = resetAt
}
}
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}

View File

@@ -0,0 +1,102 @@
//go:build unit
package service
import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
r.updatedExtra = updates
return nil
}
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
r.rateLimitedID = id
r.rateLimitedAt = &resetAt
return nil
}
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
`))
resp.Header.Set("x-codex-primary-used-percent", "88")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
resp.Header.Set("x-codex-secondary-used-percent", "42")
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
resp.Header.Set("x-codex-secondary-window-minutes", "300")
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 89,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"])
require.Contains(t, recorder.Body.String(), "test_complete")
}
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newSoraTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
resp.Header.Set("x-codex-secondary-used-percent", "100")
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
resp.Header.Set("x-codex-secondary-window-minutes", "300")
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, int64(88), repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.NotNil(t, account.RateLimitResetAt)
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
}
}

View File

@@ -1,17 +1,24 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"math/rand/v2"
"net/http"
"strings"
"sync"
"time"
httppool "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
openaipkg "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/singleflight"
)
type UsageLogRepository interface {
@@ -70,8 +77,10 @@ type accountWindowStatsBatchReader interface {
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据utilization, resets_at
// 同时支持缓存错误响应(负缓存),防止 429 等错误导致的重试风暴
type apiUsageCache struct {
response *ClaudeUsageResponse
err error // 非 nil 表示缓存的错误(负缓存)
timestamp time.Time
}
@@ -88,15 +97,21 @@ type antigravityUsageCache struct {
}
const (
apiCacheTTL = 3 * time.Minute
windowStatsCacheTTL = 1 * time.Minute
apiCacheTTL = 3 * time.Minute
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL429 等错误缓存 1 分钟
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
openAICodexProbeVersion = "0.104.0"
)
// UsageCache 封装账户使用量相关的缓存
type UsageCache struct {
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
openAIProbeCache sync.Map // accountID -> time.Time
}
// NewUsageCache 创建 UsageCache 实例
@@ -224,6 +239,14 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err)
}
if account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth {
usage, err := s.getOpenAIUsage(ctx, account)
if err == nil {
s.tryClearRecoverableAccountError(ctx, account)
}
return usage, err
}
if account.Platform == PlatformGemini {
usage, err := s.getGeminiUsage(ctx, account)
if err == nil {
@@ -245,24 +268,65 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse
// 1. 检查 API 缓存10 分钟)
// 1. 检查缓存(成功响应 3 分钟 / 错误响应 1 分钟)
if cached, ok := s.cache.apiCache.Load(accountID); ok {
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
apiResp = cache.response
if cache, ok := cached.(*apiUsageCache); ok {
age := time.Since(cache.timestamp)
if cache.err != nil && age < apiErrorCacheTTL {
// 负缓存命中:返回缓存的错误,避免重试风暴
return nil, cache.err
}
if cache.response != nil && age < apiCacheTTL {
apiResp = cache.response
}
}
}
// 2. 如果没有缓存,从 API 获取
// 2. 如果没有有效缓存,通过 singleflight 从 API 获取(防止并发击穿)
if apiResp == nil {
apiResp, err = s.fetchOAuthUsageRaw(ctx, account)
if err != nil {
return nil, err
// 随机延迟:打散多账号并发请求,避免同一时刻大量相同 TLS 指纹请求
// 触发上游反滥用检测。延迟范围 0~800ms仅在缓存未命中时生效。
jitter := time.Duration(rand.Int64N(int64(apiQueryMaxJitter)))
select {
case <-time.After(jitter):
case <-ctx.Done():
return nil, ctx.Err()
}
// 缓存 API 响应
s.cache.apiCache.Store(accountID, &apiUsageCache{
response: apiResp,
timestamp: time.Now(),
flightKey := fmt.Sprintf("usage:%d", accountID)
result, flightErr, _ := s.cache.apiFlight.Do(flightKey, func() (any, error) {
// 再次检查缓存(可能在等待 singleflight 期间被其他请求填充)
if cached, ok := s.cache.apiCache.Load(accountID); ok {
if cache, ok := cached.(*apiUsageCache); ok {
age := time.Since(cache.timestamp)
if cache.err != nil && age < apiErrorCacheTTL {
return nil, cache.err
}
if cache.response != nil && age < apiCacheTTL {
return cache.response, nil
}
}
}
resp, fetchErr := s.fetchOAuthUsageRaw(ctx, account)
if fetchErr != nil {
// 负缓存:缓存错误响应,防止后续请求重复触发 429
s.cache.apiCache.Store(accountID, &apiUsageCache{
err: fetchErr,
timestamp: time.Now(),
})
return nil, fetchErr
}
// 缓存成功响应
s.cache.apiCache.Store(accountID, &apiUsageCache{
response: resp,
timestamp: time.Now(),
})
return resp, nil
})
if flightErr != nil {
return nil, flightErr
}
apiResp, _ = result.(*ClaudeUsageResponse)
}
// 3. 构建 UsageInfo每次都重新计算 RemainingSeconds
@@ -288,6 +352,210 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{UpdatedAt: &now}
if account == nil {
return usage, nil
}
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
usage.FiveHour = progress
}
if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil {
usage.SevenDay = progress
}
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
mergeAccountExtra(account, updates)
if usage.UpdatedAt == nil {
usage.UpdatedAt = &now
}
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
usage.FiveHour = progress
}
if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil {
usage.SevenDay = progress
}
}
}
if s.usageLogRepo == nil {
return usage, nil
}
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
windowStats := windowStatsFromAccountStats(stats)
if hasMeaningfulWindowStats(windowStats) {
if usage.FiveHour == nil {
usage.FiveHour = &UsageProgress{Utilization: 0}
}
usage.FiveHour.WindowStats = windowStats
}
}
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
windowStats := windowStatsFromAccountStats(stats)
if hasMeaningfulWindowStats(windowStats) {
if usage.SevenDay == nil {
usage.SevenDay = &UsageProgress{Utilization: 0}
}
usage.SevenDay.WindowStats = windowStats
}
}
return usage, nil
}
func shouldRefreshOpenAICodexSnapshot(account *Account, usage *UsageInfo, now time.Time) bool {
if account == nil {
return false
}
if usage == nil {
return true
}
if usage.FiveHour == nil || usage.SevenDay == nil {
return true
}
if account.IsRateLimited() {
return true
}
return isOpenAICodexSnapshotStale(account, now)
}
func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool {
if account == nil || !account.IsOpenAIOAuth() || !account.IsOpenAIResponsesWebSocketV2Enabled() {
return false
}
if account.Extra == nil {
return true
}
raw, ok := account.Extra["codex_usage_updated_at"]
if !ok {
return true
}
ts, err := parseTime(fmt.Sprint(raw))
if err != nil {
return true
}
return now.Sub(ts) >= openAIProbeCacheTTL
}
func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
if s == nil || s.cache == nil || accountID <= 0 {
return true
}
if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok {
if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL {
return false
}
}
s.cache.openAIProbeCache.Store(accountID, now)
return true
}
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
if account == nil || !account.IsOAuth() {
return nil, nil
}
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
}
modelID := openaipkg.DefaultTestModel
payload := createOpenAITestPayload(modelID, true)
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
}
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("create openai probe request: %w", err)
}
req.Host = "chatgpt.com"
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("Originator", "codex_cli_rs")
req.Header.Set("Version", openAICodexProbeVersion)
req.Header.Set("User-Agent", codexCLIUserAgent)
if s.identityCache != nil {
if fp, fpErr := s.identityCache.GetFingerprint(reqCtx, account.ID); fpErr == nil && fp != nil && strings.TrimSpace(fp.UserAgent) != "" {
req.Header.Set("User-Agent", strings.TrimSpace(fp.UserAgent))
}
}
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
client, err := httppool.GetClient(httppool.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("build openai probe client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
updates, err := extractOpenAICodexProbeUpdates(resp)
if err != nil {
return nil, err
}
if len(updates) > 0 {
go func(accountID int64, updates map[string]any) {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}(account.ID, updates)
return updates, nil
}
return nil, nil
}
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
if resp == nil {
return nil, nil
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
if len(updates) > 0 {
return updates, nil
}
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
}
return nil, nil
}
func mergeAccountExtra(account *Account, updates map[string]any) {
if account == nil || len(updates) == 0 {
return
}
if account.Extra == nil {
account.Extra = make(map[string]any, len(updates))
}
for k, v := range updates {
account.Extra[k] = v
}
}
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{
@@ -519,6 +787,72 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
}
}
func hasMeaningfulWindowStats(stats *WindowStats) bool {
if stats == nil {
return false
}
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
}
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
if len(extra) == 0 {
return nil
}
var (
usedPercentKey string
resetAfterKey string
resetAtKey string
)
switch window {
case "5h":
usedPercentKey = "codex_5h_used_percent"
resetAfterKey = "codex_5h_reset_after_seconds"
resetAtKey = "codex_5h_reset_at"
case "7d":
usedPercentKey = "codex_7d_used_percent"
resetAfterKey = "codex_7d_reset_after_seconds"
resetAtKey = "codex_7d_reset_at"
default:
return nil
}
usedRaw, ok := extra[usedPercentKey]
if !ok {
return nil
}
progress := &UsageProgress{Utilization: parseExtraFloat64(usedRaw)}
if resetAtRaw, ok := extra[resetAtKey]; ok {
if resetAt, err := parseTime(fmt.Sprint(resetAtRaw)); err == nil {
progress.ResetsAt = &resetAt
progress.RemainingSeconds = int(time.Until(resetAt).Seconds())
if progress.RemainingSeconds < 0 {
progress.RemainingSeconds = 0
}
}
}
if progress.ResetsAt == nil {
if resetAfterSeconds := parseExtraInt(extra[resetAfterKey]); resetAfterSeconds > 0 {
base := now
if updatedAtRaw, ok := extra["codex_usage_updated_at"]; ok {
if updatedAt, err := parseTime(fmt.Sprint(updatedAtRaw)); err == nil {
base = updatedAt
}
}
resetAt := base.Add(time.Duration(resetAfterSeconds) * time.Second)
progress.ResetsAt = &resetAt
progress.RemainingSeconds = int(time.Until(resetAt).Seconds())
if progress.RemainingSeconds < 0 {
progress.RemainingSeconds = 0
}
}
}
return progress
}
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil {
@@ -666,15 +1000,30 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
remaining = 0
}
// 根据状态估算使用率 (百分比形式100 = 100%)
// 优先使用响应头中存储的真实 utilization 值0-1 小数,转为 0-100 百分比)
var utilization float64
switch account.SessionWindowStatus {
case "rejected":
utilization = 100.0
case "allowed_warning":
utilization = 80.0
default:
utilization = 0.0
var found bool
if stored, ok := account.Extra["session_window_utilization"]; ok {
switch v := stored.(type) {
case float64:
utilization = v * 100
found = true
case json.Number:
if f, err := v.Float64(); err == nil {
utilization = f * 100
found = true
}
}
}
// 如果没有存储的 utilization回退到状态估算
if !found {
switch account.SessionWindowStatus {
case "rejected":
utilization = 100.0
case "allowed_warning":
utilization = 80.0
}
}
info.FiveHour = &UsageProgress{

View File

@@ -0,0 +1,68 @@
package service
import (
"net/http"
"testing"
"time"
)
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
t.Parallel()
rateLimitedUntil := time.Now().Add(5 * time.Minute)
now := time.Now()
usage := &UsageInfo{
FiveHour: &UsageProgress{Utilization: 0},
SevenDay: &UsageProgress{Utilization: 0},
}
if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) {
t.Fatal("expected rate-limited account to force codex snapshot refresh")
}
if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) {
t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh")
}
if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) {
t.Fatal("expected missing 5h snapshot to require refresh")
}
staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339)
if !shouldRefreshOpenAICodexSnapshot(&Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
"codex_usage_updated_at": staleAt,
},
}, usage, now) {
t.Fatal("expected stale ws snapshot to trigger refresh")
}
}
func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "604800")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
headers.Set("x-codex-secondary-window-minutes", "300")
updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
if err != nil {
t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err)
}
if len(updates) == 0 {
t.Fatal("expected codex probe updates from 429 headers")
}
if got := updates["codex_5h_used_percent"]; got != 100.0 {
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
}
if got := updates["codex_7d_used_percent"]; got != 100.0 {
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
}

View File

@@ -145,6 +145,9 @@ type CreateGroupInput struct {
SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
DefaultMappedModel string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -181,6 +184,9 @@ type UpdateGroupInput struct {
SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool
DefaultMappedModel *string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -909,6 +915,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch,
DefaultMappedModel: input.DefaultMappedModel,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -1122,6 +1130,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.SupportedModelScopes = *input.SupportedModelScopes
}
// OpenAI Messages 调度配置
if input.AllowMessagesDispatch != nil {
group.AllowMessagesDispatch = *input.AllowMessagesDispatch
}
if input.DefaultMappedModel != nil {
group.DefaultMappedModel = *input.DefaultMappedModel
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
@@ -1333,6 +1349,10 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
if err != nil {
return nil, 0, err
}
now := time.Now()
for i := range accounts {
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
}
return accounts, result.Total, nil
}
@@ -1468,9 +1488,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
// 保留 quota_used,防止编辑账号时意外重置配额用量
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
input.Extra["quota_used"] = oldQuotaUsed
// 保留配额用量字段,防止编辑账号时意外重置
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
if v, ok := account.Extra[key]; ok {
input.Extra[key] = v
}
}
account.Extra = input.Extra
}
@@ -1701,16 +1723,10 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
}
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if err := s.accountRepo.ClearError(ctx, id); err != nil {
return nil, err
}
account.Status = StatusActive
account.ErrorMessage = ""
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
}
return account, nil
return s.accountRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {

View File

@@ -14,6 +14,11 @@ const (
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
)
const (
AnnouncementNotifyModeSilent = domain.AnnouncementNotifyModeSilent
AnnouncementNotifyModePopup = domain.AnnouncementNotifyModePopup
)
const (
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance

View File

@@ -33,23 +33,25 @@ func NewAnnouncementService(
}
type CreateAnnouncementInput struct {
Title string
Content string
Status string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
ActorID *int64 // 管理员用户ID
Title string
Content string
Status string
NotifyMode string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
ActorID *int64 // 管理员用户ID
}
type UpdateAnnouncementInput struct {
Title *string
Content *string
Status *string
Targeting *AnnouncementTargeting
StartsAt **time.Time
EndsAt **time.Time
ActorID *int64 // 管理员用户ID
Title *string
Content *string
Status *string
NotifyMode *string
Targeting *AnnouncementTargeting
StartsAt **time.Time
EndsAt **time.Time
ActorID *int64 // 管理员用户ID
}
type UserAnnouncement struct {
@@ -93,6 +95,14 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
return nil, err
}
notifyMode := strings.TrimSpace(input.NotifyMode)
if notifyMode == "" {
notifyMode = AnnouncementNotifyModeSilent
}
if !isValidAnnouncementNotifyMode(notifyMode) {
return nil, fmt.Errorf("create announcement: invalid notify_mode")
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
@@ -100,12 +110,13 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
}
a := &Announcement{
Title: title,
Content: content,
Status: status,
Targeting: targeting,
StartsAt: input.StartsAt,
EndsAt: input.EndsAt,
Title: title,
Content: content,
Status: status,
NotifyMode: notifyMode,
Targeting: targeting,
StartsAt: input.StartsAt,
EndsAt: input.EndsAt,
}
if input.ActorID != nil && *input.ActorID > 0 {
a.CreatedBy = input.ActorID
@@ -150,6 +161,14 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
a.Status = status
}
if input.NotifyMode != nil {
notifyMode := strings.TrimSpace(*input.NotifyMode)
if !isValidAnnouncementNotifyMode(notifyMode) {
return nil, fmt.Errorf("update announcement: invalid notify_mode")
}
a.NotifyMode = notifyMode
}
if input.Targeting != nil {
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
if err != nil {
@@ -376,3 +395,12 @@ func isValidAnnouncementStatus(status string) bool {
return false
}
}
func isValidAnnouncementNotifyMode(mode string) bool {
switch mode {
case AnnouncementNotifyModeSilent, AnnouncementNotifyModePopup:
return true
default:
return false
}
}

View File

@@ -1384,7 +1384,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 优先检测 thinking block 的 signature 相关错误400并重试一次
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody, maxBytes := s.getLogConfig()
@@ -1517,6 +1517,80 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
}
// Budget 整流:检测 budget_tokens 约束错误并自动修正重试
if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) {
errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "budget_constraint_error",
Message: errMsg,
Detail: s.getUpstreamErrorDetail(respBody),
})
// 修正 claudeReq 的 thinking 参数adaptive 模式不修正)
if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
// 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针
retryClaudeReq.Thinking = &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: BudgetRectifyBudgetTokens,
}
if retryClaudeReq.MaxTokens < BudgetRectifyMinMaxTokens {
retryClaudeReq.MaxTokens = BudgetRectifyMaxTokens
}
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts)
if txErr == nil {
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: action,
body: retryGeminiBody,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession,
groupID: 0,
sessionHash: "",
})
if retryErr == nil {
retryResp := retryResult.resp
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
} else {
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
}
}
} else {
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: budget rectifier retry failed: %v", account.ID, retryErr)
}
}
}
}
}
// 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
@@ -3696,6 +3770,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
cw.Write(finalEvents)
} else if !processor.MessageStartSent() && !cw.Disconnected() {
// 整个流未收到任何可解析的上游数据(全部 SSE 行均无法被 JSON 解析),
// 触发 failover 在同账号重试,避免向客户端发出缺少 message_start 的残缺流
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Claude-Stream] empty stream response (no valid events parsed), triggering failover")
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
RetryableOnSameAccount: true,
}
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil
}

View File

@@ -998,6 +998,46 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
require.True(t, result.clientDisconnect)
}
// TestHandleClaudeStreamingResponse_EmptyStream
// 验证:上游只返回无法解析的 SSE 行时,触发 UpstreamFailoverError 而不是向客户端发出残缺流
func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// 所有行均为无法 JSON 解析的内容ProcessLine 全部返回 nil
fmt.Fprintln(pw, "data: not-valid-json")
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, "data: also-invalid")
fmt.Fprintln(pw, "")
}()
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
// 应当返回 UpstreamFailoverError 而非 nil以便上层触发 failover
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.True(t, failoverErr.RetryableOnSameAccount)
// 客户端不应收到任何 SSE 事件(既无 message_start 也无 message_stop
body := rec.Body.String()
require.NotContains(t, body, "event: message_start")
require.NotContains(t, body, "event: message_stop")
require.NotContains(t, body, "event: message_delta")
}
// TestHandleClaudeStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {

View File

@@ -14,6 +14,18 @@ const (
StatusAPIKeyExpired = "expired"
)
// Rate limit window durations
const (
RateLimitWindow5h = 5 * time.Hour
RateLimitWindow1d = 24 * time.Hour
RateLimitWindow7d = 7 * 24 * time.Hour
)
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
return windowStart != nil && time.Since(*windowStart) >= duration
}
type APIKey struct {
ID int64
UserID int64
@@ -98,6 +110,30 @@ func (k *APIKey) GetDaysUntilExpiry() int {
return int(duration.Hours() / 24)
}
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage5h() float64 {
if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) {
return 0
}
return k.Usage5h
}
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage1d() float64 {
if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) {
return 0
}
return k.Usage1d
}
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage7d() float64 {
if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) {
return 0
}
return k.Usage7d
}
// APIKeyListFilters holds optional filtering parameters for listing API keys.
type APIKeyListFilters struct {
Search string

View File

@@ -65,6 +65,10 @@ type APIKeyAuthGroupSnapshot struct {
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存

View File

@@ -245,6 +245,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
MCPXMLInject: apiKey.Group.MCPXMLInject,
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
}
}
return snapshot
@@ -302,6 +304,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
MCPXMLInject: snapshot.Group.MCPXMLInject,
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
}
}
s.compileAPIKeyIPRules(apiKey)

View File

@@ -0,0 +1,245 @@
package service
import (
"testing"
"time"
)
func TestIsWindowExpired(t *testing.T) {
now := time.Now()
tests := []struct {
name string
start *time.Time
duration time.Duration
want bool
}{
{
name: "nil window start",
start: nil,
duration: RateLimitWindow5h,
want: false,
},
{
name: "active window (started 1h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-1 * time.Hour)),
duration: RateLimitWindow5h,
want: false,
},
{
name: "expired window (started 6h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-6 * time.Hour)),
duration: RateLimitWindow5h,
want: true,
},
{
name: "exactly at boundary (started 5h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-5 * time.Hour)),
duration: RateLimitWindow5h,
want: true,
},
{
name: "active 1d window (started 12h ago)",
start: rateLimitTimePtr(now.Add(-12 * time.Hour)),
duration: RateLimitWindow1d,
want: false,
},
{
name: "expired 1d window (started 25h ago)",
start: rateLimitTimePtr(now.Add(-25 * time.Hour)),
duration: RateLimitWindow1d,
want: true,
},
{
name: "active 7d window (started 3d ago)",
start: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
duration: RateLimitWindow7d,
want: false,
},
{
name: "expired 7d window (started 8d ago)",
start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
duration: RateLimitWindow7d,
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsWindowExpired(tt.start, tt.duration)
if got != tt.want {
t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAPIKey_EffectiveUsage(t *testing.T) {
now := time.Now()
tests := []struct {
name string
key APIKey
want5h float64
want1d float64
want7d float64
}{
{
name: "all windows active",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
},
want5h: 5.0,
want1d: 10.0,
want7d: 50.0,
},
{
name: "all windows expired",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "nil window starts return raw usage",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: nil,
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 5.0,
want1d: 10.0,
want7d: 50.0,
},
{
name: "mixed: 5h expired, 1d active, 7d nil",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
Window7dStart: nil,
},
want5h: 0,
want1d: 10.0,
want7d: 50.0,
},
{
name: "zero usage with active windows",
key: APIKey{
Usage5h: 0,
Usage1d: 0,
Usage7d: 0,
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.key.EffectiveUsage5h(); got != tt.want5h {
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
}
if got := tt.key.EffectiveUsage1d(); got != tt.want1d {
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
}
if got := tt.key.EffectiveUsage7d(); got != tt.want7d {
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
}
})
}
}
func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
now := time.Now()
tests := []struct {
name string
data APIKeyRateLimitData
want5h float64
want1d float64
want7d float64
}{
{
name: "all windows active",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)),
},
want5h: 3.0,
want1d: 8.0,
want7d: 40.0,
},
{
name: "all windows expired",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "nil window starts return raw usage",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: nil,
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 3.0,
want1d: 8.0,
want7d: 40.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.data.EffectiveUsage5h(); got != tt.want5h {
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
}
if got := tt.data.EffectiveUsage1d(); got != tt.want1d {
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
}
if got := tt.data.EffectiveUsage7d(); got != tt.want7d {
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
}
})
}
}
func rateLimitTimePtr(t time.Time) *time.Time {
return &t
}

View File

@@ -86,6 +86,30 @@ type APIKeyRateLimitData struct {
Window7dStart *time.Time
}
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 {
if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) {
return 0
}
return d.Usage5h
}
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 {
if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) {
return 0
}
return d.Usage1d
}
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) {
return 0
}
return d.Usage7d
}
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)

View File

@@ -565,15 +565,15 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP
needsReset := false
// Reset expired windows in-memory for check purposes
if w5h != nil && time.Since(*w5h) >= 5*time.Hour {
if IsWindowExpired(w5h, RateLimitWindow5h) {
usage5h = 0
needsReset = true
}
if w1d != nil && time.Since(*w1d) >= 24*time.Hour {
if IsWindowExpired(w1d, RateLimitWindow1d) {
usage1d = 0
needsReset = true
}
if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour {
if IsWindowExpired(w7d, RateLimitWindow7d) {
usage7d = 0
needsReset = true
}
@@ -589,12 +589,16 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP
if loader, ok := s.apiKeyRateLimitLoader.(interface {
ResetRateLimitWindows(ctx context.Context, id int64) error
}); ok {
_ = loader.ResetRateLimitWindows(resetCtx, keyID)
if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err)
}
}
}
// Invalidate cache so next request loads fresh data
if s.cache != nil {
_ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID)
if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err)
}
}
}()
}

View File

@@ -43,16 +43,19 @@ type BillingCache interface {
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致
type ModelPricing struct {
InputPricePerToken float64 // 每token输入价格 (USD)
OutputPricePerToken float64 // 每token输价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
InputPricePerToken float64 // 每token输入价格 (USD)
InputPricePerTokenPriority float64 // priority service tier 下每token输价格 (USD)
OutputPricePerToken float64 // 每token输出价格 (USD)
OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
}
const (
@@ -61,6 +64,28 @@ const (
openAIGPT54LongContextOutputMultiplier = 1.5
)
func normalizeBillingServiceTier(serviceTier string) string {
return strings.ToLower(strings.TrimSpace(serviceTier))
}
func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool {
if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" {
return false
}
return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0
}
func serviceTierCostMultiplier(serviceTier string) float64 {
switch normalizeBillingServiceTier(serviceTier) {
case "priority":
return 2.0
case "flex":
return 0.5
default:
return 1.0
}
}
// UsageTokens 使用的token数量
type UsageTokens struct {
InputTokens int
@@ -173,30 +198,60 @@ func (s *BillingService) initFallbackPricing() {
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
InputPricePerToken: 1.25e-6, // $1.25 per MTok
OutputPricePerToken: 10e-6, // $10 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.125e-6,
SupportsCacheBreakdown: false,
InputPricePerToken: 1.25e-6, // $1.25 per MTok
InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
OutputPricePerToken: 10e-6, // $10 per MTok
OutputPricePerTokenPriority: 20e-6, // $20 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.125e-6,
CacheReadPricePerTokenPriority: 0.25e-6,
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.4(业务指定价格)
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
InputPricePerToken: 2.5e-6, // $2.5 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
SupportsCacheBreakdown: false,
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
InputPricePerToken: 2.5e-6, // $2.5 per MTok
InputPricePerTokenPriority: 5e-6, // $5 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
OutputPricePerTokenPriority: 30e-6, // $30 per MTok
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok
SupportsCacheBreakdown: false,
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
}
// OpenAI GPT-5.2(本地兜底)
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
InputPricePerToken: 1.75e-6,
InputPricePerTokenPriority: 3.5e-6,
OutputPricePerToken: 14e-6,
OutputPricePerTokenPriority: 28e-6,
CacheCreationPricePerToken: 1.75e-6,
CacheReadPricePerToken: 0.175e-6,
CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false,
}
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
InputPricePerToken: 1.5e-6, // $1.5 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
CacheReadPricePerToken: 0.15e-6,
SupportsCacheBreakdown: false,
InputPricePerToken: 1.5e-6, // $1.5 per MTok
InputPricePerTokenPriority: 3e-6, // $3 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
OutputPricePerTokenPriority: 24e-6, // $24 per MTok
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
CacheReadPricePerToken: 0.15e-6,
CacheReadPricePerTokenPriority: 0.3e-6,
SupportsCacheBreakdown: false,
}
s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
InputPricePerToken: 1.75e-6,
InputPricePerTokenPriority: 3.5e-6,
OutputPricePerToken: 14e-6,
OutputPricePerTokenPriority: 28e-6,
CacheCreationPricePerToken: 1.75e-6,
CacheReadPricePerToken: 0.175e-6,
CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false,
}
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
}
@@ -241,6 +296,10 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
switch normalized {
case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"]
case "gpt-5.2":
return s.fallbackPrices["gpt-5.2"]
case "gpt-5.2-codex":
return s.fallbackPrices["gpt-5.2-codex"]
case "gpt-5.3-codex":
return s.fallbackPrices["gpt-5.3-codex"]
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
@@ -269,16 +328,19 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
enableBreakdown := price1h > 0 && price1h > price5m
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
InputPricePerToken: litellmPricing.InputCostPerToken,
InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority,
CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
}), nil
}
}
@@ -295,6 +357,10 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
// CalculateCost 计算使用费用
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
}
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
pricing, err := s.GetModelPricing(model)
if err != nil {
return nil, err
@@ -303,6 +369,21 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
breakdown := &CostBreakdown{}
inputPricePerToken := pricing.InputPricePerToken
outputPricePerToken := pricing.OutputPricePerToken
cacheReadPricePerToken := pricing.CacheReadPricePerToken
tierMultiplier := 1.0
if usePriorityServiceTierPricing(serviceTier, pricing) {
if pricing.InputPricePerTokenPriority > 0 {
inputPricePerToken = pricing.InputPricePerTokenPriority
}
if pricing.OutputPricePerTokenPriority > 0 {
outputPricePerToken = pricing.OutputPricePerTokenPriority
}
if pricing.CacheReadPricePerTokenPriority > 0 {
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
}
} else {
tierMultiplier = serviceTierCostMultiplier(serviceTier)
}
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
inputPricePerToken *= pricing.LongContextInputMultiplier
outputPricePerToken *= pricing.LongContextOutputMultiplier
@@ -329,7 +410,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
}
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier
}
// 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +

View File

@@ -522,3 +522,189 @@ func TestCalculateCost_LargeTokenCount(t *testing.T) {
require.False(t, math.IsNaN(cost.TotalCost))
require.False(t, math.IsInf(cost.TotalCost, 0))
}
func TestServiceTierCostMultiplier(t *testing.T) {
require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12)
require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12)
require.InDelta(t, 0.5, serviceTierCostMultiplier("flex"), 1e-12)
require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12)
require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12)
}
func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20}
baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0)
require.NoError(t, err)
priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority")
require.NoError(t, err)
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
}
func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0)
require.NoError(t, err)
flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex")
require.NoError(t, err)
require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10)
require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10)
require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10)
require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10)
require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10)
}
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8}
baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority")
require.NoError(t, err)
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
}
func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) {
pricingSvc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"gpt-5.4": {
InputCostPerToken: 2.5e-6,
InputCostPerTokenPriority: 5e-6,
OutputCostPerToken: 15e-6,
OutputCostPerTokenPriority: 30e-6,
CacheCreationInputTokenCost: 2.5e-6,
CacheReadInputTokenCost: 0.25e-6,
CacheReadInputTokenCostPriority: 0.5e-6,
LongContextInputTokenThreshold: 272000,
LongContextInputCostMultiplier: 2.0,
LongContextOutputCostMultiplier: 1.5,
},
},
}
svc := NewBillingService(&config.Config{}, pricingSvc)
pricing, err := svc.GetModelPricing("gpt-5.4")
require.NoError(t, err)
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12)
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
require.Equal(t, 272000, pricing.LongContextInputThreshold)
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
}
func TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) {
svc := newTestBillingService()
gpt52, err := svc.GetModelPricing("gpt-5.2")
require.NoError(t, err)
require.NotNil(t, gpt52)
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
require.NoError(t, err)
require.NotNil(t, gpt52Codex)
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
}
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) {
svc := NewBillingService(&config.Config{}, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"custom-no-priority": {
InputCostPerToken: 1e-6,
OutputCostPerToken: 2e-6,
CacheCreationInputTokenCost: 0.5e-6,
CacheReadInputTokenCost: 0.25e-6,
},
},
})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0)
require.NoError(t, err)
priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority")
require.NoError(t, err)
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
}
func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) {
svc := newTestBillingService()
gpt52, err := svc.GetModelPricing("gpt-5.2")
require.NoError(t, err)
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12)
require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12)
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
require.NoError(t, err)
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12)
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
}
func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) {
svc := NewBillingService(&config.Config{}, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"dynamic-tier-model": {
InputCostPerToken: 1e-6,
InputCostPerTokenPriority: 2e-6,
OutputCostPerToken: 3e-6,
OutputCostPerTokenPriority: 6e-6,
CacheCreationInputTokenCost: 4e-6,
CacheCreationInputTokenCostAbove1hr: 5e-6,
CacheReadInputTokenCost: 7e-7,
CacheReadInputTokenCostPriority: 8e-7,
LongContextInputTokenThreshold: 999,
LongContextInputCostMultiplier: 1.5,
LongContextOutputCostMultiplier: 1.25,
},
},
})
pricing, err := svc.GetModelPricing("dynamic-tier-model")
require.NoError(t, err)
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12)
require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
require.True(t, pricing.SupportsCacheBreakdown)
require.InDelta(t, 7e-7, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12)
require.Equal(t, 999, pricing.LongContextInputThreshold)
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
}

View File

@@ -175,6 +175,13 @@ const (
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
// =========================
// Request Rectifier (请求整流器)
// =========================
// SettingKeyRectifierSettings stores JSON config for rectifier settings (thinking signature + budget).
SettingKeyRectifierSettings = "rectifier_settings"
// =========================
// Sora S3 存储配置
// =========================

View File

@@ -177,6 +177,36 @@ func TestCheckErrorPolicy(t *testing.T) {
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
{
name: "pool_mode_custom_error_codes_hit_returns_matched",
account: &Account{
ID: 7,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(401), float64(403)},
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyMatched,
},
{
name: "pool_mode_without_custom_error_codes_returns_skipped",
account: &Account{
ID: 8,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicySkipped,
},
}
for _, tt := range tests {
@@ -190,6 +220,48 @@ func TestCheckErrorPolicy(t *testing.T) {
}
}
func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) {
t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 30,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
}
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.False(t, shouldDisable)
require.Equal(t, 0, repo.setErrCalls)
require.Equal(t, 0, repo.tempCalls)
})
t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 31,
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(401)},
},
}
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrCalls)
require.Equal(t, 0, repo.tempCalls)
})
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------

View File

@@ -86,10 +86,10 @@ func TestStripBetaTokens(t *testing.T) {
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "DroppedBetas removes both context-1m and fast-mode",
name: "DroppedBetas removes fast-mode only",
header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
tokens: claude.DroppedBetas,
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
want: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
},
}
@@ -117,21 +117,21 @@ func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
drop := droppedBetaSet()
got := mergeAnthropicBetaDropping(required, incoming, drop)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
require.NotContains(t, got, "context-1m-2025-08-07")
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07,foo-beta", got)
require.Contains(t, got, "context-1m-2025-08-07")
require.NotContains(t, got, "fast-mode-2026-02-01")
}
func TestDroppedBetaSet(t *testing.T) {
// Base set contains DroppedBetas
base := droppedBetaSet()
require.Contains(t, base, claude.BetaContext1M)
require.NotContains(t, base, claude.BetaContext1M)
require.Contains(t, base, claude.BetaFastMode)
require.Len(t, base, len(claude.DroppedBetas))
// With extra tokens
extended := droppedBetaSet(claude.BetaClaudeCode)
require.Contains(t, extended, claude.BetaContext1M)
require.NotContains(t, extended, claude.BetaContext1M)
require.Contains(t, extended, claude.BetaFastMode)
require.Contains(t, extended, claude.BetaClaudeCode)
require.Len(t, extended, len(claude.DroppedBetas)+1)
@@ -148,6 +148,32 @@ func TestBuildBetaTokenSet(t *testing.T) {
require.Empty(t, empty)
}
func TestContainsBetaToken(t *testing.T) {
tests := []struct {
name string
header string
token string
want bool
}{
{"present in middle", "oauth-2025-04-20,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true},
{"present at start", "fast-mode-2026-02-01,oauth-2025-04-20", "fast-mode-2026-02-01", true},
{"present at end", "oauth-2025-04-20,fast-mode-2026-02-01", "fast-mode-2026-02-01", true},
{"only token", "fast-mode-2026-02-01", "fast-mode-2026-02-01", true},
{"not present", "oauth-2025-04-20,interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", false},
{"with spaces", "oauth-2025-04-20, fast-mode-2026-02-01 , interleaved-thinking-2025-05-14", "fast-mode-2026-02-01", true},
{"empty header", "", "fast-mode-2026-02-01", false},
{"empty token", "fast-mode-2026-02-01", "", false},
{"partial match", "fast-mode-2026-02-01-extra", "fast-mode-2026-02-01", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := containsBetaToken(tt.header, tt.token)
require.Equal(t, tt.want, got)
})
}
}
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
got := stripBetaTokensWithSet(header, map[string]struct{}{})

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"math"
"strings"
"unsafe"
"github.com/Wei-Shaw/sub2api/internal/domain"
@@ -258,6 +259,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
if !hasEmptyContent && !containsThinkingBlocks {
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
out = removeThinkingDependentContextStrategies(out)
return out
}
return body
@@ -395,6 +397,10 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
} else {
return body
}
// Removing "thinking" makes any context_management strategy that requires it invalid
// (e.g. clear_thinking_20251015). Strip those entries so the retry request does not
// receive a 400 "strategy requires thinking to be enabled or adaptive".
out = removeThinkingDependentContextStrategies(out)
}
if modified {
msgsBytes, err := json.Marshal(messages)
@@ -409,6 +415,49 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
return out
}
// removeThinkingDependentContextStrategies 从 context_management.edits 中移除
// 需要 thinking 启用的策略(如 clear_thinking_20251015
// 当顶层 "thinking" 字段被禁用时必须调用,否则上游会返回
// "strategy requires thinking to be enabled or adaptive"。
func removeThinkingDependentContextStrategies(body []byte) []byte {
jsonStr := *(*string)(unsafe.Pointer(&body))
editsRes := gjson.Get(jsonStr, "context_management.edits")
if !editsRes.Exists() || !editsRes.IsArray() {
return body
}
var filtered []json.RawMessage
hasRemoved := false
editsRes.ForEach(func(_, v gjson.Result) bool {
if v.Get("type").String() == "clear_thinking_20251015" {
hasRemoved = true
return true
}
filtered = append(filtered, json.RawMessage(v.Raw))
return true
})
if !hasRemoved {
return body
}
if len(filtered) == 0 {
if b, err := sjson.DeleteBytes(body, "context_management.edits"); err == nil {
return b
}
return body
}
filteredBytes, err := json.Marshal(filtered)
if err != nil {
return body
}
if b, err := sjson.SetRawBytes(body, "context_management.edits", filteredBytes); err == nil {
return b
}
return body
}
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
// signature/thought_signature validation issues involving tool blocks.
//
@@ -444,6 +493,28 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
if _, exists := req["thinking"]; exists {
delete(req, "thinking")
modified = true
// Remove context_management strategies that require thinking to be enabled
// (e.g. clear_thinking_20251015), otherwise upstream returns 400.
if cm, ok := req["context_management"].(map[string]any); ok {
if edits, ok := cm["edits"].([]any); ok {
filtered := make([]any, 0, len(edits))
for _, edit := range edits {
if editMap, ok := edit.(map[string]any); ok {
if editMap["type"] == "clear_thinking_20251015" {
continue
}
}
filtered = append(filtered, edit)
}
if len(filtered) != len(edits) {
if len(filtered) == 0 {
delete(cm, "edits")
} else {
cm["edits"] = filtered
}
}
}
}
}
messages, ok := req["messages"].([]any)
@@ -675,3 +746,90 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
}
return newBody
}
// =========================
// Thinking Budget Rectifier
// =========================
const (
// BudgetRectifyBudgetTokens is the budget_tokens value to set when rectifying.
BudgetRectifyBudgetTokens = 32000
// BudgetRectifyMaxTokens is the max_tokens value to set when rectifying.
BudgetRectifyMaxTokens = 64000
// BudgetRectifyMinMaxTokens is the minimum max_tokens that must exceed budget_tokens.
BudgetRectifyMinMaxTokens = 32001
)
// isThinkingBudgetConstraintError detects whether an upstream error message indicates
// a budget_tokens constraint violation (e.g. "budget_tokens >= 1024").
// Matches three conditions (all must be true):
// 1. Contains "budget_tokens" or "budget tokens"
// 2. Contains "thinking"
// 3. Contains ">= 1024" or "greater than or equal to 1024" or ("1024" + "input should be")
func isThinkingBudgetConstraintError(errMsg string) bool {
m := strings.ToLower(errMsg)
// Condition 1: budget_tokens or budget tokens
hasBudget := strings.Contains(m, "budget_tokens") || strings.Contains(m, "budget tokens")
if !hasBudget {
return false
}
// Condition 2: thinking
if !strings.Contains(m, "thinking") {
return false
}
// Condition 3: constraint indicator
if strings.Contains(m, ">= 1024") || strings.Contains(m, "greater than or equal to 1024") {
return true
}
if strings.Contains(m, "1024") && strings.Contains(m, "input should be") {
return true
}
return false
}
// RectifyThinkingBudget modifies the request body to fix budget_tokens constraint errors.
// It sets thinking.budget_tokens = 32000, thinking.type = "enabled" (unless adaptive),
// and ensures max_tokens >= 32001.
// Returns (modified body, true) if changes were applied, or (original body, false) if not.
func RectifyThinkingBudget(body []byte) ([]byte, bool) {
// If thinking type is "adaptive", skip rectification entirely
thinkingType := gjson.GetBytes(body, "thinking.type").String()
if thinkingType == "adaptive" {
return body, false
}
modified := body
changed := false
// Set thinking.type = "enabled"
if thinkingType != "enabled" {
if result, err := sjson.SetBytes(modified, "thinking.type", "enabled"); err == nil {
modified = result
changed = true
}
}
// Set thinking.budget_tokens = 32000
currentBudget := gjson.GetBytes(modified, "thinking.budget_tokens").Int()
if currentBudget != BudgetRectifyBudgetTokens {
if result, err := sjson.SetBytes(modified, "thinking.budget_tokens", BudgetRectifyBudgetTokens); err == nil {
modified = result
changed = true
}
}
// Ensure max_tokens >= BudgetRectifyMinMaxTokens
maxTokens := gjson.GetBytes(modified, "max_tokens").Int()
if maxTokens < int64(BudgetRectifyMinMaxTokens) {
if result, err := sjson.SetBytes(modified, "max_tokens", BudgetRectifyMaxTokens); err == nil {
modified = result
changed = true
}
}
return modified, changed
}

View File

@@ -439,6 +439,210 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
require.Contains(t, content1["text"], "tool_result")
}
// ============ Group 6b: context_management.edits 清理测试 ============
// removeThinkingDependentContextStrategies — 边界用例
func TestRemoveThinkingDependentContextStrategies_NoContextManagement(t *testing.T) {
input := []byte(`{"thinking":{"type":"enabled"},"messages":[]}`)
out := removeThinkingDependentContextStrategies(input)
require.Equal(t, input, out, "无 context_management 字段时应原样返回")
}
func TestRemoveThinkingDependentContextStrategies_EmptyEdits(t *testing.T) {
input := []byte(`{"context_management":{"edits":[]},"messages":[]}`)
out := removeThinkingDependentContextStrategies(input)
require.Equal(t, input, out, "edits 为空数组时应原样返回")
}
func TestRemoveThinkingDependentContextStrategies_NoClearThinkingEntry(t *testing.T) {
input := []byte(`{"context_management":{"edits":[{"type":"other_strategy"}]},"messages":[]}`)
out := removeThinkingDependentContextStrategies(input)
require.Equal(t, input, out, "edits 中无 clear_thinking_20251015 时应原样返回")
}
func TestRemoveThinkingDependentContextStrategies_RemovesSingleEntry(t *testing.T) {
input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
out := removeThinkingDependentContextStrategies(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
cm, ok := req["context_management"].(map[string]any)
require.True(t, ok)
_, hasEdits := cm["edits"]
require.False(t, hasEdits, "所有 edits 均为 clear_thinking_20251015 时应删除 edits 键")
}
func TestRemoveThinkingDependentContextStrategies_MixedEntries(t *testing.T) {
input := []byte(`{"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_strategy","param":1}]},"messages":[]}`)
out := removeThinkingDependentContextStrategies(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
cm, ok := req["context_management"].(map[string]any)
require.True(t, ok)
edits, ok := cm["edits"].([]any)
require.True(t, ok)
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015保留其他条目")
edit0, ok := edits[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "other_strategy", edit0["type"])
}
// FilterThinkingBlocksForRetry — 包含 context_management 的场景
func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_FastPath(t *testing.T) {
// 快速路径messages 中无 thinking 块,仅有顶层 thinking 字段
// 这条路径曾因提前 return 跳过 removeThinkingDependentContextStrategies 而存在 bug
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hello"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking, "顶层 thinking 应被移除")
cm, ok := req["context_management"].(map[string]any)
require.True(t, ok)
_, hasEdits := cm["edits"]
require.False(t, hasEdits, "fast path 下 clear_thinking_20251015 应被移除edits 键应被删除")
}
func TestFilterThinkingBlocksForRetry_RemovesClearThinkingStrategy_WithThinkingBlocks(t *testing.T) {
// 完整路径messages 中有 thinking 块(非 fast path
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"keep_this"}]},
"messages":[
{"role":"assistant","content":[
{"type":"thinking","thinking":"some thought","signature":"sig"},
{"type":"text","text":"Answer"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking, "顶层 thinking 应被移除")
cm, ok := req["context_management"].(map[string]any)
require.True(t, ok)
edits, ok := cm["edits"].([]any)
require.True(t, ok)
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015保留 keep_this")
edit0, ok := edits[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "keep_this", edit0["type"])
}
func TestFilterThinkingBlocksForRetry_NoContextManagement_Unaffected(t *testing.T) {
// 无 context_management 时不应报错,且 thinking 正常被移除
input := []byte(`{
"thinking":{"type":"enabled"},
"messages":[{"role":"user","content":[{"type":"text","text":"Hi"}]}]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
_, hasCM := req["context_management"]
require.False(t, hasCM)
}
// FilterSignatureSensitiveBlocksForRetry — 包含 context_management 的场景
func TestFilterSignatureSensitiveBlocksForRetry_RemovesClearThinkingStrategy(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
"messages":[
{"role":"assistant","content":[
{"type":"thinking","thinking":"thought","signature":"sig"}
]}
]
}`)
out := FilterSignatureSensitiveBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking, "顶层 thinking 应被移除")
cm, ok := req["context_management"].(map[string]any)
require.True(t, ok)
if rawEdits, hasEdits := cm["edits"]; hasEdits {
edits, ok := rawEdits.([]any)
require.True(t, ok)
for _, e := range edits {
em, ok := e.(map[string]any)
require.True(t, ok)
require.NotEqual(t, "clear_thinking_20251015", em["type"], "clear_thinking_20251015 应被移除")
}
}
}
func TestFilterSignatureSensitiveBlocksForRetry_PreservesNonThinkingStrategies(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled"},
"context_management":{"edits":[{"type":"clear_thinking_20251015"},{"type":"other_edit"}]},
"messages":[
{"role":"assistant","content":[
{"type":"thinking","thinking":"t","signature":"s"}
]}
]
}`)
out := FilterSignatureSensitiveBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
cm, ok := req["context_management"].(map[string]any)
require.True(t, ok)
edits, ok := cm["edits"].([]any)
require.True(t, ok)
require.Len(t, edits, 1, "仅移除 clear_thinking_20251015保留 other_edit")
edit0, ok := edits[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "other_edit", edit0["type"])
}
func TestFilterSignatureSensitiveBlocksForRetry_NoThinkingField_ContextManagementUntouched(t *testing.T) {
// 没有顶层 thinking 字段时context_management 不应被修改
input := []byte(`{
"context_management":{"edits":[{"type":"clear_thinking_20251015"}]},
"messages":[
{"role":"assistant","content":[
{"type":"thinking","thinking":"t","signature":"s"}
]}
]
}`)
out := FilterSignatureSensitiveBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
cm, ok := req["context_management"].(map[string]any)
require.True(t, ok)
edits, ok := cm["edits"].([]any)
require.True(t, ok)
require.Len(t, edits, 1, "无顶层 thinking 时 context_management 不应被修改")
}
// ============ Group 7: ParseGatewayRequest 补充单元测试 ============
// Task 7.1 — 类型校验边界测试

View File

@@ -41,7 +41,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 40 * 1024 * 1024
defaultMaxLineSize = 500 * 1024 * 1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
@@ -501,33 +501,35 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
modelsListCacheTTL time.Duration
responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateResolver *userGroupRateResolver
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
modelsListCacheTTL time.Duration
settingService *SettingService
responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
}
// NewGatewayService creates a new GatewayService
@@ -552,6 +554,7 @@ func NewGatewayService(
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
digestStore *DigestSessionStore,
settingService *SettingService,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
@@ -578,10 +581,18 @@ func NewGatewayService(
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
settingService: settingService,
modelsListCache: gocache.New(modelsListTTL, time.Minute),
modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
svc.userGroupRateCache,
userGroupRateTTL,
&svc.userGroupRateSF,
"service.gateway",
)
svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
return svc
@@ -986,6 +997,11 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
}
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
func GenerateSessionUUID(seed string) string {
return generateSessionUUID(seed)
}
func generateSessionUUID(seed string) string {
if seed == "" {
return uuid.NewString()
@@ -4057,7 +4073,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@@ -4174,7 +4190,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误恢复响应体
// 不是签名错误(或整流器已关闭),继续检查 budget 约束
errMsg := extractUpstreamErrorMessage(respBody)
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "budget_constraint_error",
Message: errMsg,
Detail: func() string {
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
}
return ""
}(),
})
rectifiedBody, applied := RectifyThinkingBudget(body)
if applied && time.Since(retryStart) < maxRetryElapsed {
logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr == nil {
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
resp = budgetRetryResp
break
}
if budgetRetryResp != nil && budgetRetryResp.Body != nil {
_ = budgetRetryResp.Body.Close()
}
logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr)
} else {
logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr)
}
}
}
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
@@ -4266,7 +4320,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -4296,7 +4354,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
@@ -4531,7 +4593,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -4561,7 +4627,11 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
if resp.StatusCode >= 400 {
@@ -5276,6 +5346,19 @@ func droppedBetaSet(extra ...string) map[string]struct{} {
return m
}
// containsBetaToken checks if a comma-separated header value contains the given token.
func containsBetaToken(header, token string) bool {
if header == "" || token == "" {
return false
}
for _, p := range strings.Split(header, ",") {
if strings.TrimSpace(p) == token {
return true
}
}
return false
}
func buildBetaTokenSet(tokens []string) map[string]struct{} {
m := make(map[string]struct{}, len(tokens))
for _, t := range tokens {
@@ -5417,6 +5500,11 @@ func extractUpstreamErrorMessage(body []byte) string {
return m
}
// ChatGPT 内部 API 风格:{"detail":"..."}
if d := gjson.GetBytes(body, "detail").String(); strings.TrimSpace(d) != "" {
return d
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
@@ -6332,63 +6420,20 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
}
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
if s == nil || userID <= 0 || groupID <= 0 {
if s == nil {
return groupDefaultMultiplier
}
key := fmt.Sprintf("%d:%d", userID, groupID)
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier
}
}
resolver := s.userGroupRateResolver
if resolver == nil {
resolver = newUserGroupRateResolver(
s.userGroupRateRepo,
s.userGroupRateCache,
resolveUserGroupRateCacheTTL(s.cfg),
&s.userGroupRateSF,
"service.gateway",
)
}
if s.userGroupRateRepo == nil {
return groupDefaultMultiplier
}
userGroupRateCacheMissTotal.Add(1)
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier, nil
}
}
}
userGroupRateCacheLoadTotal.Add(1)
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
if repoErr != nil {
return nil, repoErr
}
multiplier := groupDefaultMultiplier
if userRate != nil {
multiplier = *userRate
}
if s.userGroupRateCache != nil {
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
}
return multiplier, nil
})
if shared {
userGroupRateCacheSFSharedTotal.Add(1)
}
if err != nil {
userGroupRateCacheFallbackTotal.Add(1)
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
return groupDefaultMultiplier
}
multiplier, ok := value.(float64)
if !ok {
userGroupRateCacheFallbackTotal.Add(1)
return groupDefaultMultiplier
}
return multiplier
return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier)
}
// RecordUsageInput 记录使用量的输入参数
@@ -6463,7 +6508,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
// 4. 账号配额用量账号口径TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
@@ -6954,7 +6999,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 检测 thinking block 签名错误400并重试一次过滤 thinking blocks
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocksForRetry(body)

View File

@@ -57,6 +57,10 @@ type Group struct {
// 分组排序
SortOrder int
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
DefaultMappedModel string
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -319,7 +319,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
@@ -687,16 +687,20 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
}
return &AccountSelectionResult{
Account: candidate.account,
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
@@ -705,16 +709,23 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
candidate := selectionOrder[0]
return &AccountSelectionResult{
Account: candidate.account,
WaitPlan: &AccountWaitPlan{
AccountID: candidate.account.ID,
MaxConcurrency: candidate.account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, len(candidates), topK, loadSkew, nil
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
return &AccountSelectionResult{
Account: fresh,
WaitPlan: &AccountWaitPlan{
AccountID: fresh.ID,
MaxConcurrency: fresh.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, len(candidates), topK, loadSkew, nil
}
return nil, len(candidates), topK, loadSkew, errors.New("no available accounts")
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {

View File

@@ -12,6 +12,78 @@ import (
"github.com/stretchr/testify/require"
)
type openAISnapshotCacheStub struct {
SchedulerCache
snapshotAccounts []*Account
accountsByID map[int64]*Account
}
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 {
return nil, false, nil
}
out := make([]*Account, 0, len(s.snapshotAccounts))
for _, account := range s.snapshotAccounts {
if account == nil {
continue
}
cloned := *account
out = append(out, &cloned)
}
return out, true, nil
}
func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.accountsByID == nil {
return nil, nil
}
account := s.accountsByID[accountID]
if account == nil {
return nil, nil
}
cloned := *account
return &cloned, nil
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10101)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
staleSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(31002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10102)
rateLimitedUntil := time.Now().Add(30 * time.Minute)
stalePrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}
staleSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshPrimary := &Account{ID: 32001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(32002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
ctx := context.Background()
groupID := int64(9)

View File

@@ -1,13 +1,9 @@
package service
import (
_ "embed"
"strings"
)
//go:embed prompts/codex_cli_instructions.md
var codexCLIInstructions string
var codexModelMap = map[string]string{
"gpt-5.4": "gpt-5.4",
"gpt-5.4-none": "gpt-5.4",
@@ -77,7 +73,7 @@ type codexTransformResult struct {
PromptCacheKey string
}
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
needsToolContinuation := NeedsToolContinuation(reqBody)
@@ -95,15 +91,26 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran
result.NormalizedModel = normalizedModel
}
// OAuth 走 ChatGPT internal API 时store 必须为 false显式 true 也会强制覆盖。
// 避免上游返回 "Store must be set to false"。
if v, ok := reqBody["store"].(bool); !ok || v {
reqBody["store"] = false
result.Modified = true
}
if v, ok := reqBody["stream"].(bool); !ok || !v {
reqBody["stream"] = true
result.Modified = true
if isCompact {
if _, ok := reqBody["store"]; ok {
delete(reqBody, "store")
result.Modified = true
}
if _, ok := reqBody["stream"]; ok {
delete(reqBody, "stream")
result.Modified = true
}
} else {
// OAuth 走 ChatGPT internal API 时store 必须为 false显式 true 也会强制覆盖。
// 避免上游返回 "Store must be set to false"。
if v, ok := reqBody["store"].(bool); !ok || v {
reqBody["store"] = false
result.Modified = true
}
if v, ok := reqBody["stream"].(bool); !ok || !v {
reqBody["stream"] = true
result.Modified = true
}
}
// Strip parameters unsupported by codex models via the Responses API.
@@ -219,72 +226,13 @@ func getNormalizedCodexModel(modelID string) string {
return ""
}
func getOpenCodeCodexHeader() string {
// 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。
// 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions避免读写缓存与外网依赖。
return getCodexCLIInstructions()
}
func getCodexCLIInstructions() string {
return codexCLIInstructions
}
func GetOpenCodeInstructions() string {
return getOpenCodeCodexHeader()
}
// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。
func GetCodexCLIInstructions() string {
return getCodexCLIInstructions()
}
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions使用内置 Codex CLI 指令)
// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖
// applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
if isCodexCLI {
return applyCodexCLIInstructions(reqBody)
}
return applyOpenCodeInstructions(reqBody)
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源)
func applyCodexCLIInstructions(reqBody map[string]any) bool {
if !isInstructionsEmpty(reqBody) {
return false // 已有有效 instructions不修改
return false
}
instructions := strings.TrimSpace(getCodexCLIInstructions())
if instructions != "" {
reqBody["instructions"] = instructions
return true
}
return false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名)
// 优先使用内置 Codex CLI 指令覆盖
func applyOpenCodeInstructions(reqBody map[string]any) bool {
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
existingInstructions, _ := reqBody["instructions"].(string)
existingInstructions = strings.TrimSpace(existingInstructions)
if instructions != "" {
if existingInstructions != instructions {
reqBody["instructions"] = instructions
return true
}
} else if existingInstructions == "" {
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
if codexInstructions != "" {
reqBody["instructions"] = codexInstructions
return true
}
}
return false
reqBody["instructions"] = "You are a helpful coding assistant."
return true
}
// isInstructionsEmpty 检查 instructions 字段是否为空

View File

@@ -18,7 +18,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false)
applyCodexOAuthTransform(reqBody, false, false)
// 未显式设置 store=true默认为 false。
store, ok := reqBody["store"].(bool)
@@ -53,7 +53,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false)
applyCodexOAuthTransform(reqBody, false, false)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
@@ -72,13 +72,29 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false)
applyCodexOAuthTransform(reqBody, false, false)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
require.False(t, store)
}
func TestApplyCodexOAuthTransform_CompactForcesNonStreaming(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.1-codex",
"store": true,
"stream": true,
}
result := applyCodexOAuthTransform(reqBody, true, true)
_, hasStore := reqBody["store"]
require.False(t, hasStore)
_, hasStream := reqBody["stream"]
require.False(t, hasStream)
require.True(t, result.Modified)
}
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
// 非续链场景:未设置 store 时默认 false并移除 input 中的 id。
@@ -89,7 +105,7 @@ func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(
},
}
applyCodexOAuthTransform(reqBody, false)
applyCodexOAuthTransform(reqBody, false, false)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
@@ -138,7 +154,7 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
},
}
applyCodexOAuthTransform(reqBody, false)
applyCodexOAuthTransform(reqBody, false, false)
tools, ok := reqBody["tools"].([]any)
require.True(t, ok)
@@ -158,7 +174,7 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
"input": []any{},
}
applyCodexOAuthTransform(reqBody, false)
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
@@ -193,7 +209,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *test
"instructions": "existing instructions",
}
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
result := applyCodexOAuthTransform(reqBody, true, false) // isCodexCLI=true
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
@@ -210,7 +226,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
// 没有 instructions 字段
}
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
result := applyCodexOAuthTransform(reqBody, true, false) // isCodexCLI=true
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
@@ -218,20 +234,19 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
require.True(t, result.Modified)
}
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
// 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
func TestApplyCodexOAuthTransform_NonCodexCLI_PreservesExistingInstructions(t *testing.T) {
// 非 Codex CLI 场景:已有 instructions 时保留客户端的值,不再覆盖
reqBody := map[string]any{
"model": "gpt-5.1",
"instructions": "old instructions",
}
result := applyCodexOAuthTransform(reqBody, false) // isCodexCLI=false
applyCodexOAuthTransform(reqBody, false, false) // isCodexCLI=false
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.NotEqual(t, "old instructions", instructions)
require.True(t, result.Modified)
require.Equal(t, "old instructions", instructions)
}
func TestIsInstructionsEmpty(t *testing.T) {

View File

@@ -0,0 +1,574 @@
package service
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ForwardAsAnthropic accepts an Anthropic Messages request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Anthropic Messages format. This enables Claude Code
// clients to access OpenAI models through the standard /v1/messages endpoint.
func (s *OpenAIGatewayService) ForwardAsAnthropic(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
// 1. Parse Anthropic request
var anthropicReq apicompat.AnthropicRequest
if err := json.Unmarshal(body, &anthropicReq); err != nil {
return nil, fmt.Errorf("parse anthropic request: %w", err)
}
originalModel := anthropicReq.Model
clientStream := anthropicReq.Stream // client's original stream preference
// 2. Convert Anthropic → Responses
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
if err != nil {
return nil, fmt.Errorf("convert anthropic to responses: %w", err)
}
// Upstream always uses streaming (upstream may not support sync mode).
// The client's original preference determines the response format.
responsesReq.Stream = true
isStream := true
// 2b. Handle BetaFastMode → service_tier: "priority"
if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
responsesReq.ServiceTier = "priority"
}
// 3. Model mapping
mappedModel := account.GetMappedModel(originalModel)
// 分组级降级:账号未映射时使用分组默认映射模型
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel
logger.L().Debug("openai messages: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("stream", isStream),
)
// 4. Marshal Responses request body, then apply OAuth codex transform
responsesBody, err := json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
}
if account.Type == AccountTypeOAuth {
var reqBody map[string]any
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
reqBody["prompt_cache_key"] = promptCacheKey
}
// OAuth codex transform forces stream=true upstream, so always use
// the streaming response handler regardless of what the client asked.
isStream = true
responsesBody, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
}
}
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// Override session_id with a deterministic UUID derived from the sticky
// session key (buildUpstreamRequest may have set it to the raw value).
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 8. Handle error response with failover
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
// Non-failover error: return Anthropic-formatted error to client
return s.handleAnthropicErrorResponse(resp, c, account)
}
// 9. Handle normal response
// Upstream is always streaming; choose response format based on client preference.
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
} else {
// Client wants JSON: buffer the streaming response and assemble a JSON reply.
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
}
// Propagate ServiceTier and ReasoningEffort to result for billing
if handleErr == nil && result != nil {
if responsesReq.ServiceTier != "" {
st := responsesReq.ServiceTier
result.ServiceTier = &st
}
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
re := responsesReq.Reasoning.Effort
result.ReasoningEffort = &re
}
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if handleErr == nil && account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
return result, handleErr
}
// handleAnthropicErrorResponse reads an upstream error and returns it in
// Anthropic error format.
func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Record upstream error details for ops logging
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeAnthropicError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
// the upstream streaming response, finds the terminal event (response.completed
// / response.incomplete / response.failed), converts the complete response to
// Anthropic Messages JSON format, and writes it to the client.
// This is used when the client requested stream=false but the upstream is always
// streaming.
func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Terminal events carry the complete ResponsesResponse with output + usage.
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResponse == nil {
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
return nil, fmt.Errorf("upstream stream ended without terminal event")
}
anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, anthropicResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
// converts each to Anthropic SSE events, and writes them to the client.
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
// pattern to send Anthropic ping events during periods of upstream silence,
// preventing proxy/client timeout disconnections.
func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewResponsesEventToAnthropicState()
state.Model = originalModel
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
// resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
// processDataLine handles a single "data: ..." SSE line from upstream.
// Returns (clientDisconnected bool).
processDataLine := func(payload string) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
return false
}
// Extract usage from completion events
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
// Convert to Anthropic events
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
for _, evt := range events {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
logger.L().Warn("openai messages stream: failed to marshal event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("openai messages stream: client disconnected",
zap.String("request_id", requestID),
)
return true
}
}
if len(events) > 0 {
c.Writer.Flush()
}
return false
}
// finalizeStream sends any remaining Anthropic events and returns the result.
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
for _, evt := range finalEvents {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
c.Writer.Flush()
}
return resultWithUsage(), nil
}
// handleScanErr logs scanner errors if meaningful.
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// ── Determine keepalive interval ──
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
if keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
}
handleScanErr(scanner.Err())
return finalizeStream()
}
// ── With keepalive: goroutine + channel + select ──
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
go func() {
defer close(events)
for scanner.Scan() {
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
// Upstream closed
return finalizeStream()
}
if ev.err != nil {
handleScanErr(ev.err)
return finalizeStream()
}
lastDataAt = time.Now()
line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
case <-keepaliveTicker.C:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send Anthropic-format ping event
if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil {
// Client disconnected
logger.L().Info("openai messages stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
return resultWithUsage(), nil
}
c.Writer.Flush()
}
}
}
// writeAnthropicError writes an error response in Anthropic Messages API format.
func writeAnthropicError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -0,0 +1,558 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type openAIRecordUsageLogRepoStub struct {
UsageLogRepository
inserted bool
err error
calls int
lastLog *UsageLog
}
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.calls++
s.lastLog = log
return s.inserted, s.err
}
type openAIRecordUsageUserRepoStub struct {
UserRepository
deductCalls int
deductErr error
lastAmount float64
}
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
s.deductCalls++
s.lastAmount = amount
return s.deductErr
}
type openAIRecordUsageSubRepoStub struct {
UserSubscriptionRepository
incrementCalls int
incrementErr error
}
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
s.incrementCalls++
return s.incrementErr
}
type openAIRecordUsageAPIKeyQuotaStub struct {
quotaCalls int
rateLimitCalls int
err error
lastAmount float64
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
s.quotaCalls++
s.lastAmount = cost
return s.err
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
s.rateLimitCalls++
s.lastAmount = cost
return s.err
}
type openAIUserGroupRateRepoStub struct {
UserGroupRateRepository
rate *float64
err error
calls int
}
func (s *openAIUserGroupRateRepoStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
s.calls++
if s.err != nil {
return nil, s.err
}
return s.rate, nil
}
func i64p(v int64) *int64 {
return &v
}
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return &OpenAIGatewayService{
usageLogRepo: usageRepo,
userRepo: userRepo,
userSubRepo: subRepo,
cfg: cfg,
billingService: NewBillingService(cfg, nil),
billingCacheService: &BillingCacheService{},
deferredService: &DeferredService{},
userGroupRateResolver: newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
),
}
}
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
t.Helper()
cost, err := svc.billingService.CalculateCost(model, UsageTokens{
InputTokens: max(usage.InputTokens-usage.CacheReadInputTokens, 0),
OutputTokens: usage.OutputTokens,
CacheCreationTokens: usage.CacheCreationInputTokens,
CacheReadTokens: usage.CacheReadInputTokens,
}, multiplier)
require.NoError(t, err)
return cost
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) {
groupID := int64(11)
groupRate := 1.4
userRate := 1.8
usage := OpenAIUsage{InputTokens: 15, OutputTokens: 4, CacheReadInputTokens: 3}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
rateRepo := &openAIUserGroupRateRepoStub{rate: &userRate}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_user_group_rate",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1001,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2001},
Account: &Account{ID: 3001},
})
require.NoError(t, err)
require.Equal(t, 1, rateRepo.calls)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, userRate, usageRepo.lastLog.RateMultiplier)
require.Equal(t, 12, usageRepo.lastLog.InputTokens)
require.Equal(t, 3, usageRepo.lastLog.CacheReadTokens)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, userRate)
require.InDelta(t, expected.ActualCost, usageRepo.lastLog.ActualCost, 1e-12)
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
require.Equal(t, 1, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
groupID := int64(12)
groupRate := 1.6
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 5, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
rateRepo := &openAIUserGroupRateRepoStub{err: errors.New("db unavailable")}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_group_default_on_error",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1002,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2002},
Account: &Account{ID: 3002},
})
require.NoError(t, err)
require.Equal(t, 1, rateRepo.calls)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, groupRate)
require.InDelta(t, expected.ActualCost, userRepo.lastAmount, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolverMissing(t *testing.T) {
groupID := int64(13)
groupRate := 1.25
usage := OpenAIUsage{InputTokens: 9, OutputTokens: 4, CacheReadInputTokens: 1}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc.userGroupRateResolver = nil
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_group_default_nil_resolver",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1003,
GroupID: i64p(groupID),
Group: &Group{
ID: groupID,
RateMultiplier: groupRate,
},
},
User: &User{ID: 2003},
Account: &Account{ID: 3003},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, groupRate, usageRepo.lastLog.RateMultiplier)
}
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_duplicate",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1004},
User: &User{ID: 2004},
Account: &Account{ID: 3004},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_quota_update",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1005,
Quota: 100,
},
User: &User{ID: 2005},
Account: &Account{ID: 3005},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.Equal(t, 0, quotaSvc.rateLimitCalls)
expected := expectedOpenAICost(t, svc, "gpt-5.1", usage, 1.1)
require.InDelta(t, expected.ActualCost, quotaSvc.lastAmount, 1e-12)
}
func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_clamp_actual_input",
Usage: OpenAIUsage{
InputTokens: 2,
OutputTokens: 1,
CacheReadInputTokens: 5,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1006},
User: &User{ID: 2006},
Account: &Account{ID: 3006},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, 0, usageRepo.lastLog.InputTokens)
}
func TestOpenAIGatewayServiceRecordUsage_Gpt54LongContextBillsWholeSession(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_gpt54_long_context",
Usage: OpenAIUsage{
InputTokens: 300000,
OutputTokens: 2000,
},
Model: "gpt-5.4-2026-03-05",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1014},
User: &User{ID: 2014},
Account: &Account{ID: 3014},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
expectedInput := 300000 * 2.5e-6 * 2.0
expectedOutput := 2000 * 15e-6 * 1.5
require.InDelta(t, expectedInput, usageRepo.lastLog.InputCost, 1e-10)
require.InDelta(t, expectedOutput, usageRepo.lastLog.OutputCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, usageRepo.lastLog.TotalCost, 1e-10)
require.InDelta(t, (expectedInput+expectedOutput)*1.1, usageRepo.lastLog.ActualCost, 1e-10)
require.Equal(t, 1, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_ServiceTierPriorityUsesFastPricing(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
serviceTier := "priority"
usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50}
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_service_tier_priority",
ServiceTier: &serviceTier,
Usage: usage,
Model: "gpt-5.4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1015},
User: &User{ID: 2015},
Account: &Account{ID: 3015},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.NotNil(t, usageRepo.lastLog.ServiceTier)
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 100, OutputTokens: 50}, 1.0)
require.NoError(t, calcErr)
require.InDelta(t, baseCost.TotalCost*2, usageRepo.lastLog.TotalCost, 1e-10)
}
func TestOpenAIGatewayServiceRecordUsage_ServiceTierFlexHalvesCost(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
serviceTier := "flex"
usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50, CacheReadInputTokens: 20}
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_service_tier_flex",
ServiceTier: &serviceTier,
Usage: usage,
Model: "gpt-5.4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1016},
User: &User{ID: 2016},
Account: &Account{ID: 3016},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 80, OutputTokens: 50, CacheReadTokens: 20}, 1.0)
require.NoError(t, calcErr)
require.InDelta(t, baseCost.TotalCost*0.5, usageRepo.lastLog.TotalCost, 1e-10)
}
func TestNormalizeOpenAIServiceTier(t *testing.T) {
t.Run("fast maps to priority", func(t *testing.T) {
got := normalizeOpenAIServiceTier(" fast ")
require.NotNil(t, got)
require.Equal(t, "priority", *got)
})
t.Run("default ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("default"))
})
t.Run("invalid ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
})
}
func TestExtractOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
require.Nil(t, extractOpenAIServiceTier(nil))
}
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
serviceTier := "priority"
reasoning := "high"
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_billing_model_override",
BillingModel: "gpt-5.1-codex",
Model: "gpt-5.1",
ServiceTier: &serviceTier,
ReasoningEffort: &reasoning,
Usage: OpenAIUsage{
InputTokens: 20,
OutputTokens: 10,
},
Duration: 2 * time.Second,
FirstTokenMs: func() *int { v := 120; return &v }(),
},
APIKey: &APIKey{ID: 10, GroupID: i64p(11), Group: &Group{ID: 11, RateMultiplier: 1.2}},
User: &User{ID: 20},
Account: &Account{ID: 30},
UserAgent: "codex-cli/1.0",
IPAddress: "127.0.0.1",
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model)
require.NotNil(t, usageRepo.lastLog.ServiceTier)
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
require.Equal(t, reasoning, *usageRepo.lastLog.ReasoningEffort)
require.NotNil(t, usageRepo.lastLog.UserAgent)
require.Equal(t, "codex-cli/1.0", *usageRepo.lastLog.UserAgent)
require.NotNil(t, usageRepo.lastLog.IPAddress)
require.Equal(t, "127.0.0.1", *usageRepo.lastLog.IPAddress)
require.NotNil(t, usageRepo.lastLog.GroupID)
require.Equal(t, int64(11), *usageRepo.lastLog.GroupID)
require.Equal(t, 1, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
subscription := &UserSubscription{ID: 99}
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_subscription_billing",
Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
User: &User{ID: 200},
Account: &Account{ID: 300},
Subscription: subscription,
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, BillingTypeSubscription, usageRepo.lastLog.BillingType)
require.NotNil(t, usageRepo.lastLog.SubscriptionID)
require.Equal(t, subscription.ID, *usageRepo.lastLog.SubscriptionID)
require.Equal(t, 1, subRepo.incrementCalls)
require.Equal(t, 0, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc.cfg.RunMode = config.RunModeSimple
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_simple_mode",
Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1000},
User: &User{ID: 2000},
Account: &Account{ID: 3000},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}

Some files were not shown because too many files have changed in this diff Show More